""" 认证相关 API 路由:注册、登录、刷新令牌、登出 """ import uuid import io import os import logging import traceback import secrets from datetime import datetime, timezone from pathlib import Path from typing import Optional from fastapi import APIRouter, Body, Depends, File, HTTPException, UploadFile, status from fastapi.security import OAuth2PasswordRequestForm from fastapi.responses import FileResponse from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from PIL import Image logger = logging.getLogger(__name__) from database import get_async_db from database.models import User, RefreshToken from services.auth_service import ( hash_password, verify_password, create_access_token, create_refresh_token, get_token_expires_at ) from services.sms_service import send_verification_code, verify_code from middleware.auth import get_current_user router = APIRouter(prefix="/api/auth", tags=["auth"]) # 请求模型 class RegisterRequest(BaseModel): """用户注册请求""" phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") password: str = Field(..., min_length=6, description="密码(至少6位)") nickname: str = Field(..., min_length=1, max_length=50, description="昵称") email: Optional[str] = Field(None, description="邮箱(可选)") agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") class LoginRequest(BaseModel): """用户登录请求""" phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") password: str = Field(..., min_length=1, description="密码") agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") class RefreshTokenRequest(BaseModel): """刷新令牌请求""" refresh_token: str = Field(..., description="刷新令牌") class SmsRequest(BaseModel): """发送短信验证码请求""" phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") purpose: str = Field(..., description="用途:register/login/reset_password/change_phone") class SmsLoginRequest(BaseModel): """验证码登录/注册请求(统一接口)""" phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") nickname: Optional[str] = Field(None, max_length=50, description="昵称(注册时必填,登录时可选)") class SmsRegisterRequest(BaseModel): """验证码注册请求""" phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") password: str = Field(..., min_length=6, description="密码(至少6位)") nickname: str = Field(..., min_length=1, max_length=50, description="昵称") email: Optional[str] = Field(None, description="邮箱(可选)") agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") class ResetPasswordRequest(BaseModel): """重置密码请求""" phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") new_password: str = Field(..., min_length=6, description="新密码(至少6位)") class ChangePasswordRequest(BaseModel): """修改密码请求""" old_password: str = Field(..., min_length=1, description="旧密码") new_password: str = Field(..., min_length=6, description="新密码(至少6位)") class ChangePhoneRequest(BaseModel): """修改手机号请求""" new_phone: str = Field(..., min_length=11, max_length=11, description="新手机号(11位)") code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") class UpdateNicknameRequest(BaseModel): """更新昵称请求""" nickname: str = Field(..., min_length=1, max_length=50, description="昵称(1-50个字符)") # 响应模型 class TokenResponse(BaseModel): """令牌响应""" access_token: str refresh_token: str token_type: str = "bearer" class UserResponse(BaseModel): """用户信息响应""" id: str phone: str email: Optional[str] nickname: str avatar_url: Optional[str] subscription_type: str created_at: str @router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) async def register( request: RegisterRequest, db: AsyncSession = Depends(get_async_db) ): """ 用户注册 注册成功后返回访问令牌和刷新令牌 """ # 验证是否同意用户协议和隐私政策 if not request.agreed_to_terms: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="请先阅读并同意用户协议和隐私政策" ) # 检查手机号是否已存在 stmt = select(User).where(User.phone == request.phone) result = await db.execute(stmt) existing_user = result.scalar_one_or_none() if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="该手机号已被注册" ) # 检查邮箱是否已存在(如果提供了邮箱) if request.email: stmt = select(User).where(User.email == request.email) result = await db.execute(stmt) existing_email = result.scalar_one_or_none() if existing_email: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="该邮箱已被注册" ) # 创建新用户 user_id = str(uuid.uuid4()) password_hash = hash_password(request.password) user = User( id=user_id, phone=request.phone, password_hash=password_hash, email=request.email, nickname=request.nickname, subscription_type="free", created_at=datetime.now(timezone.utc) ) db.add(user) # 创建刷新令牌 refresh_token_str = create_refresh_token() refresh_token = RefreshToken( id=str(uuid.uuid4()), user_id=user_id, token=refresh_token_str, expires_at=get_token_expires_at(), created_at=datetime.now(timezone.utc) ) db.add(refresh_token) await db.commit() await db.refresh(user) # 生成访问令牌 access_token = create_access_token(data={"sub": user_id}) return TokenResponse( access_token=access_token, refresh_token=refresh_token_str, token_type="bearer" ) @router.post("/login", response_model=TokenResponse) async def login( request: LoginRequest = Body(...), db: AsyncSession = Depends(get_async_db) ): """ 用户登录 验证手机号和密码,返回访问令牌和刷新令牌。 请求格式: - JSON: {"phone": "13800138000", "password": "xxx", "agreed_to_terms": true} """ # 验证是否同意用户协议和隐私政策 if not request.agreed_to_terms: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="请先阅读并同意用户协议和隐私政策" ) phone = request.phone password = request.password # 验证手机号格式(简单验证) if not phone or len(phone) != 11 or not phone.isdigit(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="手机号格式不正确,应为11位数字" ) # 验证密码不为空 if not password: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="密码不能为空" ) # 查找用户 stmt = select(User).where(User.phone == phone) result = await db.execute(stmt) user = result.scalar_one_or_none() if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="手机号或密码错误" ) # 验证密码 if not verify_password(password, user.password_hash): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="手机号或密码错误" ) # 创建新的刷新令牌 refresh_token_str = create_refresh_token() refresh_token = RefreshToken( id=str(uuid.uuid4()), user_id=user.id, token=refresh_token_str, expires_at=get_token_expires_at(), created_at=datetime.now(timezone.utc) ) db.add(refresh_token) await db.commit() # 生成访问令牌 access_token = create_access_token(data={"sub": user.id}) return TokenResponse( access_token=access_token, refresh_token=refresh_token_str, token_type="bearer" ) @router.post("/refresh", response_model=TokenResponse) async def refresh_token( request: RefreshTokenRequest, db: AsyncSession = Depends(get_async_db) ): """ 刷新访问令牌 使用刷新令牌获取新的访问令牌 """ # 查找刷新令牌 stmt = select(RefreshToken).where(RefreshToken.token == request.refresh_token) result = await db.execute(stmt) refresh_token = result.scalar_one_or_none() if not refresh_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的刷新令牌" ) # 检查是否已撤销 if refresh_token.is_revoked: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="刷新令牌已撤销" ) # 检查是否过期 if refresh_token.expires_at < datetime.now(timezone.utc): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="刷新令牌已过期" ) # 验证用户是否存在 user = await db.get(User, refresh_token.user_id) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在" ) # 生成新的访问令牌 access_token = create_access_token(data={"sub": user.id}) return TokenResponse( access_token=access_token, refresh_token=request.refresh_token, # 返回相同的刷新令牌 token_type="bearer" ) @router.post("/logout", status_code=status.HTTP_200_OK) async def logout( request: RefreshTokenRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ 用户登出 撤销刷新令牌 """ # 查找刷新令牌 stmt = select(RefreshToken).where( RefreshToken.token == request.refresh_token, RefreshToken.user_id == current_user.id ) result = await db.execute(stmt) refresh_token = result.scalar_one_or_none() if refresh_token: refresh_token.is_revoked = True await db.commit() return {"message": "登出成功"} @router.get("/me", response_model=UserResponse) async def get_me( current_user: User = Depends(get_current_user) ): """ 获取当前用户信息 """ return UserResponse( id=current_user.id, phone=current_user.phone, email=current_user.email, nickname=current_user.nickname, avatar_url=current_user.avatar_url, subscription_type=current_user.subscription_type, created_at=current_user.created_at.isoformat() ) # 头像存储目录 AVATAR_DIR = Path("uploads/avatars") AVATAR_DIR.mkdir(parents=True, exist_ok=True) @router.post("/me/avatar", response_model=UserResponse) async def upload_avatar( file: UploadFile = File(...), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ 上传用户头像 支持格式:JPEG, PNG, WebP 最大大小:5MB 自动裁剪为正方形并压缩 """ # 验证文件类型 allowed_types = ["image/jpeg", "image/png", "image/webp"] logger.info(f"上传头像 - 文件名: {file.filename}, Content-Type: {file.content_type}, Size: {file.size if hasattr(file, 'size') else 'unknown'}") if file.content_type not in allowed_types: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"不支持的文件类型。仅支持: {', '.join(allowed_types)}" ) # 验证文件大小(5MB) file_content = await file.read() logger.info(f"读取文件内容 - 大小: {len(file_content)} bytes") # 验证文件内容不为空 if not file_content or len(file_content) == 0: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="文件内容为空" ) if len(file_content) > 5 * 1024 * 1024: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="文件大小超过5MB限制" ) try: # 确保目录存在 AVATAR_DIR.mkdir(parents=True, exist_ok=True) # 创建BytesIO对象并确保位置在开头 image_bytes = io.BytesIO(file_content) image_bytes.seek(0) # 验证文件是否为有效的图片格式(通过检查文件头) image_bytes.seek(0) header = image_bytes.read(16) # 读取更多字节以确保能检测WebP image_bytes.seek(0) # 检查常见图片格式的文件头 is_valid_image = False if header.startswith(b'\xff\xd8\xff'): # JPEG is_valid_image = True logger.info("检测到JPEG格式") elif header.startswith(b'\x89PNG\r\n\x1a\n'): # PNG is_valid_image = True logger.info("检测到PNG格式") elif header.startswith(b'RIFF') and b'WEBP' in header[:12]: # WebP (RIFF...WEBP) is_valid_image = True logger.info("检测到WebP格式") else: logger.warning(f"无法识别的文件头: {header[:12].hex()}") if not is_valid_image: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的图片文件格式。文件头: {header[:12].hex()}" ) # 打开并处理图片 image = Image.open(image_bytes) logger.info(f"成功打开图片 - 格式: {image.format}, 模式: {image.mode}, 尺寸: {image.size}") # 转换为RGB模式(处理RGBA等格式) if image.mode != "RGB": image = image.convert("RGB") # 裁剪为正方形(取中心部分) width, height = image.size size = min(width, height) left = (width - size) // 2 top = (height - size) // 2 right = left + size bottom = top + size image = image.crop((left, top, right, bottom)) # 调整大小(最大512x512) if size > 512: image = image.resize((512, 512), Image.Resampling.LANCZOS) # 生成文件名(使用用户ID) file_extension = "jpg" filename = f"{current_user.id}.{file_extension}" file_path = AVATAR_DIR / filename # 保存图片(JPEG格式,质量85%) image.save(file_path, "JPEG", quality=85, optimize=True) # 生成URL(相对路径,前端需要拼接BASE_URL) avatar_url = f"/api/auth/avatars/{filename}" # 更新数据库 current_user.avatar_url = avatar_url await db.commit() await db.refresh(current_user) return UserResponse( id=current_user.id, phone=current_user.phone, email=current_user.email, nickname=current_user.nickname, avatar_url=current_user.avatar_url, subscription_type=current_user.subscription_type, created_at=current_user.created_at.isoformat() ) except Exception as e: error_msg = f"处理图片失败: {str(e)}" logger.error(f"头像上传失败: {error_msg}\n{traceback.format_exc()}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_msg ) @router.get("/avatars/{filename}") async def get_avatar(filename: str): """ 获取用户头像 返回头像文件,如果不存在则返回404 """ file_path = AVATAR_DIR / filename if not file_path.exists(): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="头像不存在" ) return FileResponse( file_path, media_type="image/jpeg" ) # ============================================================================ # 短信验证码相关路由 # ============================================================================ @router.post("/sms/send") async def send_sms_code( request: SmsRequest, db: AsyncSession = Depends(get_async_db) ): """ 发送短信验证码 用途: - login: 登录/注册(统一接口,用户不存在时自动注册) - register: 注册(保留兼容,建议使用login) - reset_password: 重置密码 - change_phone: 修改手机号 """ # 验证手机号格式 if not request.phone.isdigit(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="手机号格式不正确" ) # 验证用途 valid_purposes = ["register", "login", "reset_password", "change_phone"] if request.purpose not in valid_purposes: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的用途,必须是: {', '.join(valid_purposes)}" ) # 对于注册,检查手机号是否已存在 if request.purpose == "register": stmt = select(User).where(User.phone == request.phone) result = await db.execute(stmt) existing_user = result.scalar_one_or_none() if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="该手机号已被注册" ) # 对于重置密码,检查手机号是否存在 if request.purpose == "reset_password": stmt = select(User).where(User.phone == request.phone) result = await db.execute(stmt) user = result.scalar_one_or_none() if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="该手机号未注册" ) # 对于登录,不再检查用户是否存在,允许自动注册 # 发送验证码时统一使用login目的,登录接口会自动处理注册逻辑 # 发送验证码 success, message, expires_in = await send_verification_code( db=db, phone=request.phone, purpose=request.purpose, ip_address=None # TODO: 从request中获取IP地址 ) if not success: # 根据错误类型返回适当的HTTP状态码 if "频繁" in message: status_code = status.HTTP_429_TOO_MANY_REQUESTS elif "配置" in message or "配置错误" in message or "授权失败" in message: # 配置错误或授权失败,返回503服务不可用 status_code = status.HTTP_503_SERVICE_UNAVAILABLE else: # 其他错误返回500内部服务器错误 status_code = status.HTTP_500_INTERNAL_SERVER_ERROR raise HTTPException( status_code=status_code, detail=message ) return {"message": message, "expires_in": expires_in} @router.post("/login/sms", response_model=TokenResponse) async def login_with_sms( request: SmsLoginRequest, db: AsyncSession = Depends(get_async_db) ): """ 验证码登录/注册(统一接口) 使用手机号和验证码登录,如果用户不存在则自动注册。 注册时需要提供昵称。 """ # 验证是否同意用户协议和隐私政策 if not request.agreed_to_terms: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="请先阅读并同意用户协议和隐私政策" ) # 验证验证码(支持login和register两种目的,兼容旧流程) # 先尝试用login验证,如果失败再尝试register验证 success = False message = "" for purpose in ["login", "register"]: success, message = await verify_code( db=db, phone=request.phone, code=request.code, purpose=purpose ) if success: break if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=message ) # 查找用户 stmt = select(User).where(User.phone == request.phone) result = await db.execute(stmt) user = result.scalar_one_or_none() # 如果用户不存在,自动注册 if not user: # 创建新用户 user_id = str(uuid.uuid4()) # 生成一个随机密码(用户不会用到,但数据库要求必填) random_password = secrets.token_urlsafe(32) password_hash = hash_password(random_password) # 如果提供了昵称就使用,否则使用空字符串(表示需要后续设置) nickname = request.nickname.strip() if request.nickname else "" user = User( id=user_id, phone=request.phone, password_hash=password_hash, email=None, nickname=nickname, subscription_type="free", created_at=datetime.now(timezone.utc) ) db.add(user) await db.commit() await db.refresh(user) # 创建刷新令牌 refresh_token_str = create_refresh_token() refresh_token = RefreshToken( id=str(uuid.uuid4()), user_id=user.id, token=refresh_token_str, expires_at=get_token_expires_at(), created_at=datetime.now(timezone.utc) ) db.add(refresh_token) await db.commit() # 生成访问令牌 access_token = create_access_token(data={"sub": user.id}) return TokenResponse( access_token=access_token, refresh_token=refresh_token_str, token_type="bearer" ) @router.post("/register/sms", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) async def register_with_sms( request: SmsRegisterRequest, db: AsyncSession = Depends(get_async_db) ): """ 验证码注册 使用验证码验证后完成注册 """ # 验证是否同意用户协议和隐私政策 if not request.agreed_to_terms: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="请先阅读并同意用户协议和隐私政策" ) # 验证验证码 success, message = await verify_code( db=db, phone=request.phone, code=request.code, purpose="register" ) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=message ) # 检查手机号是否已存在(双重验证) stmt = select(User).where(User.phone == request.phone) result = await db.execute(stmt) existing_user = result.scalar_one_or_none() if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="该手机号已被注册" ) # 检查邮箱是否已存在(如果提供了邮箱) if request.email: stmt = select(User).where(User.email == request.email) result = await db.execute(stmt) existing_email = result.scalar_one_or_none() if existing_email: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="该邮箱已被注册" ) # 创建新用户 user_id = str(uuid.uuid4()) password_hash = hash_password(request.password) user = User( id=user_id, phone=request.phone, password_hash=password_hash, email=request.email, nickname=request.nickname, subscription_type="free", created_at=datetime.now(timezone.utc) ) db.add(user) # 创建刷新令牌 refresh_token_str = create_refresh_token() refresh_token = RefreshToken( id=str(uuid.uuid4()), user_id=user_id, token=refresh_token_str, expires_at=get_token_expires_at(), created_at=datetime.now(timezone.utc) ) db.add(refresh_token) await db.commit() await db.refresh(user) # 生成访问令牌 access_token = create_access_token(data={"sub": user_id}) return TokenResponse( access_token=access_token, refresh_token=refresh_token_str, token_type="bearer" ) @router.post("/password/reset") async def reset_password( request: ResetPasswordRequest, db: AsyncSession = Depends(get_async_db) ): """ 重置密码(忘记密码) 使用验证码重置密码 """ # 验证验证码 success, message = await verify_code( db=db, phone=request.phone, code=request.code, purpose="reset_password" ) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=message ) # 查找用户 stmt = select(User).where(User.phone == request.phone) result = await db.execute(stmt) user = result.scalar_one_or_none() if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在" ) # 更新密码 user.password_hash = hash_password(request.new_password) await db.commit() return {"message": "密码重置成功"} @router.post("/password/change") async def change_password( request: ChangePasswordRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ 修改密码(已登录) 需要验证旧密码 """ # 验证旧密码 if not verify_password(request.old_password, current_user.password_hash): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="旧密码错误" ) # 更新密码 current_user.password_hash = hash_password(request.new_password) await db.commit() return {"message": "密码修改成功"} @router.post("/phone/change", response_model=UserResponse) async def change_phone( request: ChangePhoneRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ 修改手机号 需要验证新手机号的验证码 """ # 验证验证码 success, message = await verify_code( db=db, phone=request.new_phone, code=request.code, purpose="change_phone" ) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=message ) # 检查新手机号是否已被使用 stmt = select(User).where(User.phone == request.new_phone) result = await db.execute(stmt) existing_user = result.scalar_one_or_none() if existing_user and existing_user.id != current_user.id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="该手机号已被其他用户使用" ) # 更新手机号 current_user.phone = request.new_phone await db.commit() await db.refresh(current_user) return UserResponse( id=current_user.id, phone=current_user.phone, email=current_user.email, nickname=current_user.nickname, avatar_url=current_user.avatar_url, subscription_type=current_user.subscription_type, created_at=current_user.created_at.isoformat() ) @router.put("/me/nickname", response_model=UserResponse) async def update_nickname( request: UpdateNicknameRequest, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ 更新用户昵称 用于首次登录后设置昵称,或修改现有昵称 """ # 更新昵称 current_user.nickname = request.nickname.strip() await db.commit() await db.refresh(current_user) return UserResponse( id=current_user.id, phone=current_user.phone, email=current_user.email, nickname=current_user.nickname, avatar_url=current_user.avatar_url, subscription_type=current_user.subscription_type, created_at=current_user.created_at.isoformat() ) @router.post("/logout/all") async def logout_all_devices( current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ 登出所有设备 撤销当前用户的所有刷新令牌 """ # 查找用户的所有刷新令牌 stmt = select(RefreshToken).where( RefreshToken.user_id == current_user.id, RefreshToken.is_revoked == False ) result = await db.execute(stmt) tokens = result.scalars().all() # 撤销所有令牌 for token in tokens: token.is_revoked = True await db.commit() return {"message": f"已登出所有设备,共撤销 {len(tokens)} 个令牌"}