"""Authentication API routes: register, login, refresh token, logout, current user."""
import uuid
from datetime import datetime, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from database import get_async_db
from database.models import User, RefreshToken
from services.auth_service import (
    hash_password,
    verify_password,
    create_access_token,
    create_refresh_token,
    get_token_expires_at,
)
from middleware.auth import get_current_user

router = APIRouter(prefix="/api/auth", tags=["auth"])

# NOTE: the Chinese docstrings on the endpoints and pydantic models below are
# kept verbatim on purpose: FastAPI/pydantic publish them as descriptions in
# the generated OpenAPI schema, so they are user-visible API documentation.


def _ensure_valid_phone(phone: Optional[str]) -> None:
    """Raise HTTP 400 unless *phone* is exactly 11 ASCII digits.

    Shared by /register and /login so that every phone number that can be
    registered can also be used to log in. ``str.isdigit`` alone would also
    accept non-ASCII digits (e.g. full-width ones), letting the same number be
    stored under two different spellings, so ASCII is required as well.
    """
    if not phone or len(phone) != 11 or not (phone.isascii() and phone.isdigit()):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="手机号格式不正确,应为11位数字"
        )


def _as_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware datetime.

    Naive values are interpreted as UTC. NOTE(review): this assumes the
    database stores timestamps in UTC (this module writes
    ``datetime.now(timezone.utc)``); some drivers return them without tzinfo,
    and comparing naive with aware datetimes raises ``TypeError``.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------

class RegisterRequest(BaseModel):
    """用户注册请求"""
    # Phone number: exactly 11 characters (digit check done in the endpoint).
    phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)")
    # Plain-text password, at least 6 characters; hashed before storage.
    password: str = Field(..., min_length=6, description="密码(至少6位)")
    # Display name, 1-50 characters.
    nickname: str = Field(..., min_length=1, max_length=50, description="昵称")
    # Optional email; not format-validated here.
    email: Optional[str] = Field(None, description="邮箱(可选)")


class LoginRequest(BaseModel):
    """用户登录请求"""
    # Phone number used as the login identifier.
    phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)")
    # Plain-text password to verify against the stored hash.
    password: str = Field(..., min_length=1, description="密码")


class RefreshTokenRequest(BaseModel):
    """刷新令牌请求"""
    # Opaque refresh token previously issued by /register or /login.
    refresh_token: str = Field(..., description="刷新令牌")


# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------

class TokenResponse(BaseModel):
    """令牌响应"""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"


class UserResponse(BaseModel):
    """用户信息响应"""
    id: str
    phone: str
    email: Optional[str]
    nickname: str
    avatar_url: Optional[str]
    subscription_type: str
    # ISO-8601 string produced from User.created_at.
    created_at: str


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

# POST /api/auth/register
# Creates a "free" user plus a refresh-token row, commits both in one
# transaction, and returns a new access/refresh token pair (HTTP 201).
# Raises HTTP 400 for a malformed phone or an already-registered phone/email.
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
async def register(
    request: RegisterRequest,
    db: AsyncSession = Depends(get_async_db)
):
    """
    用户注册

    注册成功后返回访问令牌和刷新令牌
    """
    # Same format rule as /login; without it an account could be created
    # that /login would always reject.
    _ensure_valid_phone(request.phone)

    # A blank email is treated as "not provided": the duplicate check below
    # skips falsy emails, so storing "" would bypass that check.
    email = request.email or None

    # Reject an already-registered phone number.
    result = await db.execute(select(User).where(User.phone == request.phone))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="该手机号已被注册"
        )

    # Reject an already-registered email (only when one was supplied).
    if email:
        result = await db.execute(select(User).where(User.email == email))
        if result.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="该邮箱已被注册"
            )

    user_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc)

    db.add(User(
        id=user_id,
        phone=request.phone,
        password_hash=hash_password(request.password),
        email=email,
        nickname=request.nickname,
        subscription_type="free",
        created_at=now
    ))

    # Persist the refresh token alongside the user in the same transaction.
    refresh_token_str = create_refresh_token()
    db.add(RefreshToken(
        id=str(uuid.uuid4()),
        user_id=user_id,
        token=refresh_token_str,
        expires_at=get_token_expires_at(),
        created_at=now
    ))

    try:
        await db.commit()
    except IntegrityError:
        # The SELECT checks above are not atomic with the INSERT, so a
        # concurrent registration can slip in between. If the schema enforces
        # uniqueness on phone/email (not visible here), report the conflict
        # as a 400 instead of letting it escape as a 500.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="该手机号或邮箱已被注册"
        ) from None

    return TokenResponse(
        access_token=create_access_token(data={"sub": user_id}),
        refresh_token=refresh_token_str,
        token_type="bearer"
    )


# POST /api/auth/login
# Verifies phone + password and, on success, stores a new refresh token and
# returns a fresh access/refresh token pair.
# Raises HTTP 400 for a malformed phone/empty password and HTTP 401 for an
# unknown phone or wrong password (same message for both).
@router.post("/login", response_model=TokenResponse)
async def login(
    request: LoginRequest,
    db: AsyncSession = Depends(get_async_db)
):
    """
    用户登录

    验证手机号和密码,返回访问令牌和刷新令牌
    """
    _ensure_valid_phone(request.phone)

    # Defensive: LoginRequest already enforces min_length=1 on password.
    if not request.password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="密码不能为空"
        )

    result = await db.execute(select(User).where(User.phone == request.phone))
    user = result.scalar_one_or_none()

    # NOTE(review): the unknown-phone path returns without calling
    # verify_password, so response timing may reveal whether a phone number
    # is registered even though the error message is identical.
    if not user or not verify_password(request.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="手机号或密码错误"
        )

    refresh_token_str = create_refresh_token()
    db.add(RefreshToken(
        id=str(uuid.uuid4()),
        user_id=user.id,
        token=refresh_token_str,
        expires_at=get_token_expires_at(),
        created_at=datetime.now(timezone.utc)
    ))
    await db.commit()

    return TokenResponse(
        access_token=create_access_token(data={"sub": user.id}),
        refresh_token=refresh_token_str,
        token_type="bearer"
    )


# POST /api/auth/refresh
# Exchanges a stored, non-revoked, unexpired refresh token for a new access
# token. The refresh token itself is not rotated: the same value is returned.
# Raises HTTP 401 if the token is unknown, revoked, expired, or its user is gone.
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(
    request: RefreshTokenRequest,
    db: AsyncSession = Depends(get_async_db)
):
    """
    刷新访问令牌

    使用刷新令牌获取新的访问令牌
    """
    result = await db.execute(
        select(RefreshToken).where(RefreshToken.token == request.refresh_token)
    )
    stored_token = result.scalar_one_or_none()

    if not stored_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的刷新令牌"
        )

    if stored_token.is_revoked:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="刷新令牌已撤销"
        )

    # Normalize before comparing: a naive expires_at from the DB compared with
    # an aware "now" would raise TypeError instead of returning 401.
    if _as_utc(stored_token.expires_at) < datetime.now(timezone.utc):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="刷新令牌已过期"
        )

    user = await db.get(User, stored_token.user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在"
        )

    return TokenResponse(
        access_token=create_access_token(data={"sub": user.id}),
        refresh_token=request.refresh_token,  # same refresh token (no rotation)
        token_type="bearer"
    )


# POST /api/auth/logout
# Marks the given refresh token as revoked, but only if it belongs to the
# authenticated user. Unknown tokens are ignored, so the call always
# reports success.
@router.post("/logout", status_code=status.HTTP_200_OK)
async def logout(
    request: RefreshTokenRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_async_db)
):
    """
    用户登出

    撤销刷新令牌
    """
    result = await db.execute(
        select(RefreshToken).where(
            RefreshToken.token == request.refresh_token,
            RefreshToken.user_id == current_user.id
        )
    )
    stored_token = result.scalar_one_or_none()

    if stored_token:
        stored_token.is_revoked = True
        await db.commit()

    return {"message": "登出成功"}


# GET /api/auth/me
# Returns the authenticated user's profile; created_at is serialized with
# datetime.isoformat().
@router.get("/me", response_model=UserResponse)
async def get_me(
    current_user: User = Depends(get_current_user)
):
    """
    获取当前用户信息
    """
    return UserResponse(
        id=current_user.id,
        phone=current_user.phone,
        email=current_user.email,
        nickname=current_user.nickname,
        avatar_url=current_user.avatar_url,
        subscription_type=current_user.subscription_type,
        created_at=current_user.created_at.isoformat()
    )