"""HTTP routes for authentication, profile management and avatar handling.

Every endpoint is mounted under ``/api/auth``.  Business rules live in
``AuthService``; this module validates requests, delegates to the service and
maps ``AuthError`` codes onto HTTP status codes.

Route handlers are documented with ``#`` comments rather than docstrings on
purpose: FastAPI publishes a handler's docstring as the OpenAPI operation
description, and the public API documentation is driven by ``summary=``.
"""

import io
import traceback
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from PIL import Image

from app.core.dependencies import get_current_user
from app.core.logging import get_logger
from app.features.auth.deps import get_auth_service
from app.features.auth.schemas import (
    ChangePasswordRequest,
    ChangePhoneRequest,
    LoginRequest,
    RefreshTokenRequest,
    RegisterRequest,
    ResetPasswordRequest,
    SendSmsRequest,
    SmsLoginRequest,
    SmsRegisterRequest,
    TokenResponse,
    UpdateNicknameRequest,
    UserResponse,
)
from app.features.auth.service import AuthError, AuthService
from app.features.user.models import User

logger = get_logger(__name__)

router = APIRouter(
    prefix="/api/auth",
    tags=["auth"],
    responses={401: {"description": "认证失败"}},
)

AVATAR_DIR = Path("uploads/avatars")
AVATAR_DIR.mkdir(parents=True, exist_ok=True)

# Upload constraints for avatars.
_ALLOWED_AVATAR_TYPES = ("image/jpeg", "image/png", "image/webp")
_MAX_AVATAR_BYTES = 5 * 1024 * 1024
_AVATAR_MAX_SIDE = 512

# Purposes accepted by the SMS sending endpoint.
_VALID_SMS_PURPOSES = ("register", "login", "reset_password", "change_phone")


# ── helpers ──────────────────────────────────────────────────

# AuthError.code -> HTTP status; any code not listed here maps to 400.
_ERROR_STATUS: dict[str, int] = {
    "INVALID_CREDENTIALS": status.HTTP_401_UNAUTHORIZED,
    "INVALID_TOKEN": status.HTTP_401_UNAUTHORIZED,
    "TOKEN_REVOKED": status.HTTP_401_UNAUTHORIZED,
    "TOKEN_EXPIRED": status.HTTP_401_UNAUTHORIZED,
    "USER_NOT_FOUND": status.HTTP_404_NOT_FOUND,
    "PHONE_EXISTS": status.HTTP_400_BAD_REQUEST,
}


def _map_auth_error(e: AuthError) -> HTTPException:
    """Build the ``HTTPException`` that corresponds to a service ``AuthError``.

    The status comes from ``_ERROR_STATUS`` (400 when the code is unknown) and
    the detail is the error's own message.  The exception is returned, not
    raised, so call sites read ``raise _map_auth_error(e) from e``.
    """
    return HTTPException(
        status_code=_ERROR_STATUS.get(e.code, status.HTTP_400_BAD_REQUEST),
        detail=e.message,
    )


def _user_response(user: User) -> UserResponse:
    """Serialize a ``User`` into the public ``UserResponse`` schema."""
    return UserResponse(
        id=user.id,
        phone=user.phone,
        email=user.email,
        nickname=user.nickname,
        avatar_url=user.avatar_url,
        subscription_type=user.subscription_type,
        created_at=user.created_at.isoformat(),
    )


def _token_response(result) -> TokenResponse:
    """Build a ``TokenResponse`` from a service result.

    ``result`` must be subscriptable with the keys ``"access_token"`` and
    ``"refresh_token"``.
    """
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
    )


def _check_terms(agreed: bool) -> None:
    """Raise a 400 ``HTTPException`` unless the terms were accepted."""
    if not agreed:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="请先阅读并同意用户协议和隐私政策",
        )


def _sniff_image_format(header: bytes) -> Optional[str]:
    """Identify an image by magic number.

    Returns ``"JPEG"``, ``"PNG"`` or ``"WebP"``, or ``None`` when the leading
    bytes match none of them.  The client-supplied Content-Type is not
    trusted, hence this check on the actual bytes.
    """
    if header.startswith(b"\xff\xd8\xff"):
        return "JPEG"
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "PNG"
    # RIFF container: bytes 0-3 are "RIFF", 4-7 the chunk size, 8-11 the form
    # type, which must be "WEBP".
    if header.startswith(b"RIFF") and header[8:12] == b"WEBP":
        return "WebP"
    return None


def _decode_avatar(content: bytes) -> Image.Image:
    """Decode uploaded bytes into a square RGB avatar image.

    The image is converted to RGB, center-cropped to a square and scaled down
    to ``_AVATAR_MAX_SIDE`` pixels per side when larger.

    Raises:
        HTTPException: 400 when Pillow cannot decode the data (corrupt or
            truncated file, or a decompression bomb).  This is a client
            error, not a server failure.
    """
    try:
        image = Image.open(io.BytesIO(content))
        logger.info(
            f"成功打开图片 - 格式: {image.format}, 模式: {image.mode}, 尺寸: {image.size}"
        )
        if image.mode != "RGB":
            image = image.convert("RGB")

        width, height = image.size
        side = min(width, height)
        left = (width - side) // 2
        top = (height - side) // 2
        image = image.crop((left, top, left + side, top + side))

        if side > _AVATAR_MAX_SIDE:
            image = image.resize(
                (_AVATAR_MAX_SIDE, _AVATAR_MAX_SIDE),
                Image.Resampling.LANCZOS,
            )
    except (OSError, Image.DecompressionBombError) as e:
        # Image.open is lazy, so decode errors may only surface in
        # convert/crop; all of them are covered by this handler.
        logger.warning(f"头像图片解码失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="图片文件已损坏或无法解析",
        ) from e
    return image


# ── registration & login ─────────────────────────────────────


@router.post(
    "/register",
    response_model=TokenResponse,
    status_code=status.HTTP_201_CREATED,
    summary="手机号密码注册",
    responses={400: {"description": "手机号/邮箱已注册或参数错误"}},
)
async def register(
    request: RegisterRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Register with phone + password and return a fresh token pair.
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.register(
            phone=request.phone,
            password=request.password,
            nickname=request.nickname,
            email=request.email,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return _token_response(result)


@router.post(
    "/login",
    response_model=TokenResponse,
    summary="手机号密码登录",
    responses={401: {"description": "手机号或密码错误"}},
)
async def login(
    request: LoginRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Log in with phone + password and return a token pair.
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.login(
            phone=request.phone,
            password=request.password,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return _token_response(result)


@router.post(
    "/refresh",
    response_model=TokenResponse,
    summary="刷新访问令牌",
    responses={401: {"description": "刷新令牌无效/已撤销/已过期"}},
)
async def refresh_token(
    request: RefreshTokenRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Exchange a refresh token for a new token pair.
    try:
        result = await service.refresh_tokens(
            refresh_token=request.refresh_token,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return _token_response(result)


# ── logout ────────────────────────────────────────────────────


@router.post(
    "/logout",
    status_code=status.HTTP_200_OK,
    summary="登出当前设备",
)
async def logout(
    request: RefreshTokenRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Log out the current device using the supplied refresh token.
    await service.logout(request.refresh_token, current_user.id)
    return {"message": "登出成功"}


@router.post(
    "/logout/all",
    summary="登出所有设备",
)
async def logout_all_devices(
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Log out every device of the current user and report the count returned
    # by the service.
    count = await service.logout_all(current_user.id)
    return {"message": f"已登出所有设备,共撤销 {count} 个令牌"}


# ── user profile ──────────────────────────────────────────────


@router.get(
    "/me",
    response_model=UserResponse,
    summary="获取当前用户信息",
)
async def get_me(
    current_user: User = Depends(get_current_user),
):
    # Return the authenticated user's profile.
    return _user_response(current_user)


@router.put(
    "/me/nickname",
    response_model=UserResponse,
    summary="修改昵称",
)
async def update_nickname(
    request: UpdateNicknameRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Change the authenticated user's nickname.
    try:
        user = await service.update_nickname(current_user.id, request.nickname)
    except AuthError as e:
        raise _map_auth_error(e) from e
    return _user_response(user)


# ── avatar ────────────────────────────────────────────────────


@router.post(
    "/me/avatar",
    response_model=UserResponse,
    summary="上传头像",
    responses={400: {"description": "文件类型或大小不符合要求"}},
)
async def upload_avatar(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Validate the upload (declared type, size, magic number), normalize it to
    # a square JPEG named after the user id, store it under AVATAR_DIR and
    # record the public URL on the user.
    #
    # Client-side problems answer 400; storage or unexpected failures answer
    # 500 with a generic detail (the full error goes to the log only, so
    # filesystem paths and internals are not leaked to the caller).
    logger.info(
        f"上传头像 - 文件名: {file.filename}, Content-Type: {file.content_type}, "
        f"Size: {getattr(file, 'size', 'unknown')}"
    )

    if file.content_type not in _ALLOWED_AVATAR_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的文件类型。仅支持: {', '.join(_ALLOWED_AVATAR_TYPES)}",
        )

    file_content = await file.read()
    logger.info(f"读取文件内容 - 大小: {len(file_content)} bytes")

    if not file_content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="文件内容为空",
        )

    if len(file_content) > _MAX_AVATAR_BYTES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="文件大小超过5MB限制",
        )

    header = file_content[:16]
    detected_format = _sniff_image_format(header)
    if detected_format is None:
        logger.warning(f"无法识别的文件头: {header[:12].hex()}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的图片文件格式。文件头: {header[:12].hex()}",
        )
    logger.info(f"检测到{detected_format}格式")

    filename = f"{current_user.id}.jpg"
    try:
        image = _decode_avatar(file_content)

        # Re-create the directory in case it was removed after startup.
        AVATAR_DIR.mkdir(parents=True, exist_ok=True)
        image.save(AVATAR_DIR / filename, "JPEG", quality=85, optimize=True)

        user = await service.update_avatar_url(
            current_user.id,
            f"/api/auth/avatars/{filename}",
        )
    except HTTPException:
        raise
    except AuthError as e:
        # Same mapping as every other handler in this module.
        raise _map_auth_error(e) from e
    except Exception as e:
        logger.error(
            f"头像上传失败: 处理图片失败: {str(e)}\n{traceback.format_exc()}"
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="处理图片失败",
        ) from e

    return _user_response(user)


@router.get(
    "/avatars/{filename}",
    summary="获取头像图片",
    responses={404: {"description": "头像不存在"}},
)
async def get_avatar(filename: str):
    # Serve a stored avatar.  Only a bare file name that refers to a regular
    # file inside AVATAR_DIR is accepted: anything carrying a path component,
    # and names such as "." or ".." that resolve to a directory, answer 404
    # instead of escaping the directory or crashing FileResponse.
    file_path = AVATAR_DIR / filename
    if Path(filename).name != filename or not file_path.is_file():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="头像不存在",
        )
    return FileResponse(file_path, media_type="image/jpeg")


# ── SMS verification ──────────────────────────────────────────


@router.post(
    "/sms/send",
    summary="发送短信验证码",
    responses={
        400: {"description": "手机号格式或用途不合法"},
        429: {"description": "发送过于频繁"},
        503: {"description": "短信服务不可用"},
    },
)
async def send_sms_code(
    request: SendSmsRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Send a verification code to the given phone for the given purpose.
    #
    # isascii() is required in addition to isdigit(): str.isdigit() alone also
    # accepts non-ASCII digit characters (e.g. superscripts, Arabic-Indic).
    phone = request.phone
    if not (phone.isascii() and phone.isdigit()):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="手机号格式不正确",
        )

    if request.purpose not in _VALID_SMS_PURPOSES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的用途,必须是: {', '.join(_VALID_SMS_PURPOSES)}",
        )

    try:
        # NOTE(review): ip_address is always None here, so the service never
        # receives the caller's IP — confirm whether that is intended.
        success, message, expires_in = await service.send_sms_code(
            phone=phone,
            purpose=request.purpose,
            ip_address=None,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e

    if not success:
        # The service reports failures as free text; classify by keyword.
        if "频繁" in message:
            failure_status = status.HTTP_429_TOO_MANY_REQUESTS
        elif "配置" in message or "授权失败" in message:
            failure_status = status.HTTP_503_SERVICE_UNAVAILABLE
        else:
            failure_status = status.HTTP_500_INTERNAL_SERVER_ERROR
        raise HTTPException(status_code=failure_status, detail=message)

    return {"message": message, "expires_in": expires_in}


@router.post(
    "/login/sms",
    response_model=TokenResponse,
    summary="短信验证码登录(新用户自动注册)",
    responses={400: {"description": "验证码错误"}},
)
async def login_with_sms(
    request: SmsLoginRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Log in with an SMS verification code and return a token pair.
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.login_with_sms(
            phone=request.phone,
            code=request.code,
            nickname=request.nickname,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return _token_response(result)


@router.post(
    "/register/sms",
    response_model=TokenResponse,
    status_code=status.HTTP_201_CREATED,
    summary="短信验证码注册",
    responses={400: {"description": "验证码错误或手机号/邮箱已注册"}},
)
async def register_with_sms(
    request: SmsRegisterRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Register with an SMS verification code and return a token pair.
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.register_with_sms(
            phone=request.phone,
            code=request.code,
            password=request.password,
            nickname=request.nickname,
            email=request.email,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return _token_response(result)


# ── password & phone management ───────────────────────────────


@router.post(
    "/password/reset",
    summary="通过短信验证码重置密码",
    responses={
        400: {"description": "验证码错误"},
        404: {"description": "用户不存在"},
    },
)
async def reset_password(
    request: ResetPasswordRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Reset a password using an SMS verification code (no login required).
    try:
        await service.reset_password(
            phone=request.phone,
            code=request.code,
            new_password=request.new_password,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return {"message": "密码重置成功"}


@router.post(
    "/password/change",
    summary="修改密码(需旧密码)",
    responses={400: {"description": "旧密码错误"}},
)
async def change_password(
    request: ChangePasswordRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Change the authenticated user's password; the old password is required.
    try:
        await service.change_password(
            user_id=current_user.id,
            old_password=request.old_password,
            new_password=request.new_password,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return {"message": "密码修改成功"}


@router.post(
    "/phone/change",
    response_model=UserResponse,
    summary="更换手机号",
    responses={400: {"description": "验证码错误或手机号已被占用"}},
)
async def change_phone(
    request: ChangePhoneRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Move the authenticated user to a new phone number, verified by SMS code.
    try:
        user = await service.change_phone(
            user_id=current_user.id,
            new_phone=request.new_phone,
            code=request.code,
        )
    except AuthError as e:
        raise _map_auth_error(e) from e
    return _user_response(user)