547 lines
16 KiB
Python
547 lines
16 KiB
Python
import io
|
|
from pathlib import Path
|
|
|
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
|
from fastapi.responses import FileResponse
|
|
from PIL import Image
|
|
|
|
from app.core.config import settings
|
|
from app.core.dependencies import get_current_user
|
|
from app.core.logging import get_logger
|
|
from app.features.auth.deps import get_auth_service
|
|
from app.features.auth.schemas import (
|
|
ChangePasswordRequest,
|
|
ChangePhoneRequest,
|
|
LoginRequest,
|
|
MockSmsLoginRequest,
|
|
RefreshTokenRequest,
|
|
RegisterRequest,
|
|
ResetPasswordRequest,
|
|
SendSmsRequest,
|
|
SmsLoginRequest,
|
|
SmsRegisterRequest,
|
|
TokenResponse,
|
|
UpdateNicknameRequest,
|
|
UserResponse,
|
|
)
|
|
from app.features.auth.service import AuthError, AuthService
|
|
from app.features.user.models import User
|
|
|
|
logger = get_logger(__name__)

# All endpoints below are mounted under /api/auth; 401 is documented router-wide.
router = APIRouter(
    tags=["auth"],
    prefix="/api/auth",
    responses={401: {"description": "认证失败"}},
)

# Avatar JPEGs live on local disk under this path. It is relative, so it is
# resolved against the process working directory; created eagerly at import.
AVATAR_DIR = Path("uploads") / "avatars"
AVATAR_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
# ── helpers ──────────────────────────────────────────────────
|
|
|
|
# AuthError.code -> HTTP status code. Any code not listed here is reported as
# 400 by _map_auth_error.
_ERROR_STATUS: dict[str, int] = {
    # Credential / token problems are all authentication failures.
    **dict.fromkeys(
        ("INVALID_CREDENTIALS", "INVALID_TOKEN", "TOKEN_REVOKED", "TOKEN_EXPIRED"),
        status.HTTP_401_UNAUTHORIZED,
    ),
    "USER_NOT_FOUND": status.HTTP_404_NOT_FOUND,
    "PHONE_EXISTS": status.HTTP_400_BAD_REQUEST,
}
|
|
|
|
|
|
def _map_auth_error(e: AuthError) -> HTTPException:
    """Translate a service-layer ``AuthError`` into an ``HTTPException``.

    The status comes from ``_ERROR_STATUS`` keyed by ``e.code`` (400 when the
    code is not listed); ``e.message`` becomes the response detail. The
    exception is returned, not raised, so callers write ``raise _map_auth_error(e)``.
    """
    return HTTPException(
        status_code=_ERROR_STATUS.get(e.code, status.HTTP_400_BAD_REQUEST),
        detail=e.message,
    )
|
|
|
|
|
|
def _user_response(user: User) -> UserResponse:
    """Build the public ``UserResponse`` payload for *user*.

    ``created_at`` is serialized with ``isoformat()``; all other fields are
    copied through unchanged.
    """
    fields = {
        "id": user.id,
        "phone": user.phone,
        "email": user.email,
        "nickname": user.nickname,
        "avatar_url": user.avatar_url,
        "subscription_type": user.subscription_type,
        "created_at": user.created_at.isoformat(),
    }
    return UserResponse(**fields)
|
|
|
|
|
|
def _check_terms(agreed: bool) -> None:
    """Raise 400 unless the client confirmed acceptance of the terms/privacy policy."""
    if agreed:
        return
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="请先阅读并同意用户协议和隐私政策",
    )
|
|
|
|
|
|
# APP_ENV values (already stripped/lower-cased) treated as production. The
# mock route must fail closed, so the common "prod" alias is included too.
_PRODUCTION_ENVIRONMENTS = frozenset({"production", "prod"})


def _mock_sms_login_route_enabled() -> bool:
    """Return whether the ``/mock/sms-login`` route may be served.

    Always False when ``settings.app_environment`` names a production
    environment (case-insensitive, surrounding whitespace ignored). Otherwise
    it follows the truthiness of ``settings.mock_sms_login_enabled``.
    """
    env = (settings.app_environment or "").strip().lower()
    if env in _PRODUCTION_ENVIRONMENTS:
        return False
    return bool(settings.mock_sms_login_enabled)
|
|
|
|
|
|
# ── registration & login ─────────────────────────────────────
|
|
|
|
|
|
@router.post(
    "/register",
    response_model=TokenResponse,
    status_code=status.HTTP_201_CREATED,
    summary="手机号密码注册",
    responses={400: {"description": "手机号/邮箱已注册或参数错误"}},
)
async def register(
    request: RegisterRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Phone + password registration. Requires terms acceptance, delegates to
    # AuthService.register and returns the issued access/refresh token pair.
    # AuthError is translated by _map_auth_error (400 unless its code maps elsewhere).
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.register(
            phone=request.phone,
            password=request.password,
            nickname=request.nickname,
            email=request.email,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
    )
|
|
|
|
|
|
@router.post(
    "/login",
    response_model=TokenResponse,
    summary="手机号密码登录",
    responses={401: {"description": "手机号或密码错误"}},
)
async def login(
    request: LoginRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Phone + password login. Terms acceptance is re-checked on every login;
    # AuthError (e.g. INVALID_CREDENTIALS -> 401) is translated by _map_auth_error.
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.login(
            phone=request.phone,
            password=request.password,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
    )
|
|
|
|
|
|
@router.post(
    "/refresh",
    response_model=TokenResponse,
    summary="刷新访问令牌",
    responses={401: {"description": "刷新令牌无效/已撤销/已过期"}},
)
async def refresh_token(
    request: RefreshTokenRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Exchange a refresh token for a new token pair via AuthService.refresh_tokens.
    # INVALID_TOKEN / TOKEN_REVOKED / TOKEN_EXPIRED map to 401 through _ERROR_STATUS.
    try:
        result = await service.refresh_tokens(
            refresh_token=request.refresh_token,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
    )
|
|
|
|
|
|
# ── logout ────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post(
    "/logout",
    status_code=status.HTTP_200_OK,
    summary="登出当前设备",
)
async def logout(
    request: RefreshTokenRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Log out the current device: hands the submitted refresh token and the
    # caller's id to AuthService.logout (presumably revoking that token).
    # Like every other endpoint here, an AuthError is mapped to its HTTP status
    # instead of escaping as an unhandled 500.
    try:
        await service.logout(request.refresh_token, current_user.id)
    except AuthError as exc:
        raise _map_auth_error(exc) from exc
    return {"message": "登出成功"}
|
|
|
|
|
|
@router.post(
    "/logout/all",
    summary="登出所有设备",
)
async def logout_all_devices(
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Log out everywhere: AuthService.logout_all returns how many tokens it
    # revoked for the caller, which is echoed in the response message.
    # AuthError is mapped to its HTTP status, consistent with the other endpoints.
    try:
        count = await service.logout_all(current_user.id)
    except AuthError as exc:
        raise _map_auth_error(exc) from exc
    return {"message": f"已登出所有设备,共撤销 {count} 个令牌"}
|
|
|
|
|
|
# ── user profile ──────────────────────────────────────────────
|
|
|
|
|
|
@router.get(
    "/me",
    summary="获取当前用户信息",
    response_model=UserResponse,
)
async def get_me(
    current_user: User = Depends(get_current_user),
):
    # Profile of the authenticated caller; get_current_user resolves the user
    # before this body runs, so no service call is needed.
    profile = _user_response(current_user)
    return profile
|
|
|
|
|
|
@router.put(
    "/me/nickname",
    response_model=UserResponse,
    summary="修改昵称",
)
async def update_nickname(
    request: UpdateNicknameRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Update the caller's nickname via AuthService.update_nickname and return
    # the refreshed profile. AuthError is translated by _map_auth_error.
    try:
        user = await service.update_nickname(current_user.id, request.nickname)
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return _user_response(user)
|
|
|
|
|
|
# ── avatar ────────────────────────────────────────────────────
|
|
|
|
|
|
# Largest accepted avatar upload, in bytes (5 MiB).
_AVATAR_MAX_BYTES = 5 * 1024 * 1024
# Edge length, in pixels, that larger square avatars are downscaled to.
_AVATAR_EDGE = 512


def _has_supported_image_signature(header: bytes) -> bool:
    """Return True if *header* begins with a JPEG, PNG or WebP magic number."""
    if header.startswith((b"\xff\xd8\xff", b"\x89PNG\r\n\x1a\n")):
        return True
    # WebP is a RIFF container: b"RIFF" + 4-byte length + b"WEBP".
    return header.startswith(b"RIFF") and b"WEBP" in header[:12]


def _prepare_avatar_image(file_content: bytes) -> Image.Image:
    """Decode *file_content* into a centered square RGB image of at most 512x512.

    Synchronous and CPU-bound; the endpoint runs it in a worker thread so the
    event loop is not blocked while Pillow decodes/resizes.

    Raises:
        OSError: Pillow cannot identify or fully decode the data
            (``UnidentifiedImageError`` is an ``OSError`` subclass).
        Image.DecompressionBombError: the pixel count exceeds Pillow's limit.
    """
    from PIL import ImageOps

    with Image.open(io.BytesIO(file_content)) as source:
        logger.debug(
            "头像解码: format={} mode={} size={}",
            source.format,
            source.mode,
            source.size,
        )
        # Honour the EXIF Orientation tag. The JPEG re-encode below does not
        # keep EXIF, so without this phone photos would be stored rotated.
        # exif_transpose returns a new, fully loaded image, so it stays valid
        # after the source file is closed.
        image = ImageOps.exif_transpose(source)

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Center-crop to the largest square that fits.
    width, height = image.size
    edge = min(width, height)
    left = (width - edge) // 2
    top = (height - edge) // 2
    image = image.crop((left, top, left + edge, top + edge))

    if edge > _AVATAR_EDGE:
        image = image.resize((_AVATAR_EDGE, _AVATAR_EDGE), Image.Resampling.LANCZOS)
    return image


def _avatar_processing_error(exc: Exception) -> HTTPException:
    """Log an unexpected avatar-pipeline failure (with traceback) and build the 500.

    Must be called from inside an ``except`` block so the traceback is captured.
    """
    logger.exception("头像上传失败: {}", exc)
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="处理图片失败,请重试",
    )


@router.post(
    "/me/avatar",
    response_model=UserResponse,
    summary="上传头像",
    responses={400: {"description": "文件类型或大小不符合要求"}},
)
async def upload_avatar(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Validate an uploaded JPEG/PNG/WebP, normalise it to a <=512px square
    # JPEG stored as AVATAR_DIR/<user_id>.jpg, and record its URL on the user.
    # Rejections are 400. Undecodable images and decompression bombs are also
    # 400. Unexpected failures are logged and returned as a generic 500.
    import asyncio

    allowed_types = ["image/jpeg", "image/png", "image/webp"]
    # The client-declared type is only a first filter; magic bytes are checked below.
    if file.content_type not in allowed_types:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的文件类型。仅支持: {', '.join(allowed_types)}",
        )

    # Read at most one byte past the limit so an oversized upload is rejected
    # without pulling the entire body into memory.
    file_content = await file.read(_AVATAR_MAX_BYTES + 1)
    if not file_content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="文件内容为空",
        )
    if len(file_content) > _AVATAR_MAX_BYTES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="文件大小超过5MB限制",
        )

    logger.debug(
        "上传头像: user_id={} filename={} content_type={} size={}",
        current_user.id,
        file.filename,
        file.content_type,
        len(file_content),
    )

    header = file_content[:16]
    if not _has_supported_image_signature(header):
        logger.warning("无法识别的图片文件头")
        logger.debug("无法识别的文件头 hex={}", header[:12].hex())
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的图片文件格式。文件头: {header[:12].hex()}",
        )

    # Decode in a worker thread. Bad image data is the client's fault (400).
    # Anything else is unexpected (500).
    try:
        image = await asyncio.to_thread(_prepare_avatar_image, file_content)
    except (OSError, Image.DecompressionBombError) as exc:
        logger.warning("头像图片解码失败: {}", exc)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="图片内容损坏或无法解析",
        ) from exc
    except Exception as exc:
        raise _avatar_processing_error(exc) from exc

    filename = f"{current_user.id}.jpg"
    try:
        # Re-create the directory in case it was removed after import.
        AVATAR_DIR.mkdir(parents=True, exist_ok=True)
        await asyncio.to_thread(
            image.save, AVATAR_DIR / filename, "JPEG", quality=85, optimize=True
        )
        user = await service.update_avatar_url(
            current_user.id, f"/api/auth/avatars/{filename}"
        )
    except AuthError as exc:
        raise _map_auth_error(exc) from exc
    except Exception as exc:
        raise _avatar_processing_error(exc) from exc
    return _user_response(user)
|
|
|
|
|
|
@router.get(
    "/avatars/{filename}",
    summary="获取头像图片",
    responses={404: {"description": "头像不存在"}},
)
async def get_avatar(filename: str):
    # Serve a stored avatar JPEG. `filename` comes straight from the URL, so
    # the joined path is resolved and must be a regular file located directly
    # inside AVATAR_DIR. The checks reject "..", absolute paths, separator
    # tricks (e.g. backslashes on Windows) and directories, which previously
    # passed the bare exists() check; all of these get the same 404.
    avatar_root = AVATAR_DIR.resolve()
    file_path = (avatar_root / filename).resolve()
    if file_path.parent != avatar_root or not file_path.is_file():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="头像不存在",
        )
    return FileResponse(file_path, media_type="image/jpeg")
|
|
|
|
|
|
# ── SMS verification ──────────────────────────────────────────
|
|
|
|
|
|
@router.post(
    "/sms/send",
    summary="发送短信验证码",
    responses={
        400: {"description": "手机号格式或用途不合法"},
        429: {"description": "发送过于频繁"},
        503: {"description": "短信服务不可用"},
    },
)
async def send_sms_code(
    request: SendSmsRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Send a verification code for one of the supported purposes. The service
    # reports failure as (False, message, ...), and the HTTP status is derived
    # from keywords in that message.
    # str.isdigit() alone also accepts non-ASCII digits (full-width "１",
    # superscript "²", ...), so require plain ASCII digits.
    if not (request.phone.isascii() and request.phone.isdigit()):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="手机号格式不正确",
        )

    valid_purposes = ("register", "login", "reset_password", "change_phone")
    if request.purpose not in valid_purposes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的用途,必须是: {', '.join(valid_purposes)}",
        )

    try:
        # NOTE(review): ip_address is always None, so any per-IP throttling in
        # the service is inactive. Passing request.client.host is only safe if
        # proxy headers are trusted/configured; confirm before changing.
        success, message, expires_in = await service.send_sms_code(
            phone=request.phone,
            purpose=request.purpose,
            ip_address=None,
        )
    except AuthError as exc:
        raise _map_auth_error(exc) from exc

    if not success:
        if "频繁" in message:  # rate limited
            status_code = status.HTTP_429_TOO_MANY_REQUESTS
        elif "配置" in message or "授权失败" in message:  # provider misconfigured / unauthorized
            status_code = status.HTTP_503_SERVICE_UNAVAILABLE
        else:
            status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        raise HTTPException(status_code=status_code, detail=message)

    return {"message": message, "expires_in": expires_in}
|
|
|
|
|
|
@router.post(
    "/login/sms",
    response_model=TokenResponse,
    summary="短信验证码登录(新用户自动注册)",
    responses={400: {"description": "验证码错误"}},
)
async def login_with_sms(
    request: SmsLoginRequest,
    service: AuthService = Depends(get_auth_service),
):
    # SMS-code login (the summary states new users are auto-registered by the
    # service). Requires terms acceptance; AuthError is translated by _map_auth_error.
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.login_with_sms(
            phone=request.phone,
            code=request.code,
            nickname=request.nickname,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
    )
|
|
|
|
|
|
@router.post(
    "/mock/sms-login",
    response_model=TokenResponse,
    summary="[评测] Mock 短信登录(跳过验证码)",
    description=(
        "需 MOCK_SMS_LOGIN_ENABLED=1 且 APP_ENV 非 production。"
        "供 Eval Web 等内网工具联调,勿在生产环境开启。"
    ),
    responses={404: {"description": "未启用或生产环境已禁用"}},
)
async def mock_sms_login_route(
    request: MockSmsLoginRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Evaluation-only login that skips the SMS code. When the feature gate is
    # off, answer a plain 404 so the route looks nonexistent, checked before
    # anything else.
    if not _mock_sms_login_route_enabled():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.mock_sms_login(
            phone=request.phone,
            nickname=request.nickname,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
    )
|
|
|
|
|
|
@router.post(
    "/register/sms",
    response_model=TokenResponse,
    status_code=status.HTTP_201_CREATED,
    summary="短信验证码注册",
    responses={400: {"description": "验证码错误或手机号/邮箱已注册"}},
)
async def register_with_sms(
    request: SmsRegisterRequest,
    service: AuthService = Depends(get_auth_service),
):
    # SMS-verified registration with a password. Requires terms acceptance,
    # delegates to AuthService.register_with_sms and returns the token pair.
    _check_terms(request.agreed_to_terms)
    try:
        result = await service.register_with_sms(
            phone=request.phone,
            code=request.code,
            password=request.password,
            nickname=request.nickname,
            email=request.email,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return TokenResponse(
        access_token=result["access_token"],
        refresh_token=result["refresh_token"],
    )
|
|
|
|
|
|
# ── password & phone management ───────────────────────────────
|
|
|
|
|
|
@router.post(
    "/password/reset",
    summary="通过短信验证码重置密码",
    responses={
        400: {"description": "验证码错误"},
        404: {"description": "用户不存在"},
    },
)
async def reset_password(
    request: ResetPasswordRequest,
    service: AuthService = Depends(get_auth_service),
):
    # Unauthenticated password reset proven by an SMS code. USER_NOT_FOUND maps
    # to 404 via _ERROR_STATUS; other AuthError codes default to 400.
    try:
        await service.reset_password(
            phone=request.phone,
            code=request.code,
            new_password=request.new_password,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return {"message": "密码重置成功"}
|
|
|
|
|
|
@router.post(
    "/password/change",
    summary="修改密码(需旧密码)",
    responses={400: {"description": "旧密码错误"}},
)
async def change_password(
    request: ChangePasswordRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Authenticated password change; the service verifies old_password.
    # AuthError is translated by _map_auth_error.
    try:
        await service.change_password(
            user_id=current_user.id,
            old_password=request.old_password,
            new_password=request.new_password,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return {"message": "密码修改成功"}
|
|
|
|
|
|
@router.post(
    "/phone/change",
    response_model=UserResponse,
    summary="更换手机号",
    responses={400: {"description": "验证码错误或手机号已被占用"}},
)
async def change_phone(
    request: ChangePhoneRequest,
    current_user: User = Depends(get_current_user),
    service: AuthService = Depends(get_auth_service),
):
    # Move the caller's account to a new phone number, proven by an SMS code,
    # and return the updated profile. PHONE_EXISTS and bad codes surface as 400.
    try:
        user = await service.change_phone(
            user_id=current_user.id,
            new_phone=request.new_phone,
            code=request.code,
        )
    except AuthError as exc:
        # Chain the cause so the original AuthError survives in tracebacks.
        raise _map_auth_error(exc) from exc
    return _user_response(user)
|