diff --git a/api/database/models.py b/api/database/models.py index c4eb4dd..1bf62fe 100644 --- a/api/database/models.py +++ b/api/database/models.py @@ -137,7 +137,24 @@ class RefreshToken(Base): expires_at = Column(DateTime(timezone=True), nullable=False) # 过期时间(30天后) created_at = Column(DateTime(timezone=True), default=utc_now) is_revoked = Column(Boolean, default=False) # 是否已撤销 + device_info = Column(String, nullable=True) # 设备信息(用于全设备登出) # Relationships user = relationship("User", back_populates="refresh_tokens") + +class SmsVerificationCode(Base): + """短信验证码表""" + __tablename__ = "sms_verification_codes" + + id = Column(String, primary_key=True) + phone = Column(String, nullable=False, index=True) # 手机号 + code = Column(String, nullable=False) # 6位验证码 + purpose = Column(String, nullable=False) # register/login/reset_password/change_phone + is_used = Column(Boolean, default=False) # 是否已使用 + is_expired = Column(Boolean, default=False) # 是否已过期 + expires_at = Column(DateTime(timezone=True), nullable=False) # 过期时间(5分钟后) + created_at = Column(DateTime(timezone=True), default=utc_now) + verified_at = Column(DateTime(timezone=True), nullable=True) # 验证时间 + ip_address = Column(String, nullable=True) # 请求IP地址 + diff --git a/api/routers/auth.py b/api/routers/auth.py index c14997c..ca1bfab 100644 --- a/api/routers/auth.py +++ b/api/routers/auth.py @@ -28,6 +28,7 @@ from services.auth_service import ( create_refresh_token, get_token_expires_at ) +from services.sms_service import send_verification_code, verify_code from middleware.auth import get_current_user router = APIRouter(prefix="/api/auth", tags=["auth"]) @@ -53,6 +54,46 @@ class RefreshTokenRequest(BaseModel): refresh_token: str = Field(..., description="刷新令牌") +class SmsRequest(BaseModel): + """发送短信验证码请求""" + phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") + purpose: str = Field(..., description="用途:register/login/reset_password/change_phone") + + +class SmsLoginRequest(BaseModel): + """验证码登录请求""" + phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") + code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") + + +class SmsRegisterRequest(BaseModel): + """验证码注册请求""" + phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") + code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") + password: str = Field(..., min_length=6, description="密码(至少6位)") + nickname: str = Field(..., min_length=1, max_length=50, description="昵称") + email: Optional[str] = Field(None, description="邮箱(可选)") + + +class ResetPasswordRequest(BaseModel): + """重置密码请求""" + phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") + code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") + new_password: str = Field(..., min_length=6, description="新密码(至少6位)") + + +class ChangePasswordRequest(BaseModel): + """修改密码请求""" + old_password: str = Field(..., min_length=1, description="旧密码") + new_password: str = Field(..., min_length=6, description="新密码(至少6位)") + + +class ChangePhoneRequest(BaseModel): + """修改手机号请求""" + new_phone: str = Field(..., min_length=11, max_length=11, description="新手机号(11位)") + code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") + + # 响应模型 class TokenResponse(BaseModel): """令牌响应""" @@ -465,3 +506,366 @@ async def get_avatar(filename: str): file_path, media_type="image/jpeg" ) + + +# ============================================================================ +# 短信验证码相关路由 +# ============================================================================ + +@router.post("/sms/send") +async def send_sms_code( + request: SmsRequest, + db: AsyncSession = Depends(get_async_db) +): + """ + 发送短信验证码 + + 用途: + - register: 注册 + - login: 登录 + - reset_password: 重置密码 + - change_phone: 修改手机号 + """ + # 验证手机号格式 + if not request.phone.isdigit(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="手机号格式不正确" + ) + + # 验证用途 + valid_purposes = ["register", "login", "reset_password", "change_phone"] + if request.purpose not in valid_purposes: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"无效的用途,必须是: {', '.join(valid_purposes)}" + ) + + # 对于注册,检查手机号是否已存在 + if request.purpose == "register": + stmt = select(User).where(User.phone == request.phone) + result = await db.execute(stmt) + existing_user = result.scalar_one_or_none() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="该手机号已被注册" + ) + + # 对于登录和重置密码,检查手机号是否存在 + if request.purpose in ["login", "reset_password"]: + stmt = select(User).where(User.phone == request.phone) + result = await db.execute(stmt) + user = result.scalar_one_or_none() + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="该手机号未注册" + ) + + # 发送验证码 + success, message, expires_in = await send_verification_code( + db=db, + phone=request.phone, + purpose=request.purpose, + ip_address=None # TODO: 从request中获取IP地址 + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS if "频繁" in message else status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=message + ) + + return {"message": message, "expires_in": expires_in} + + +@router.post("/login/sms", response_model=TokenResponse) +async def login_with_sms( + request: SmsLoginRequest, + db: AsyncSession = Depends(get_async_db) +): + """ + 验证码登录 + + 使用手机号和验证码登录 + """ + # 验证验证码 + success, message = await verify_code( + db=db, + phone=request.phone, + code=request.code, + purpose="login" + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=message + ) + + # 查找用户 + stmt = select(User).where(User.phone == request.phone) + result = await db.execute(stmt) + user = result.scalar_one_or_none() + + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在" + ) + + # 创建刷新令牌 + refresh_token_str = create_refresh_token() + refresh_token = RefreshToken( + id=str(uuid.uuid4()), + user_id=user.id, + token=refresh_token_str, + expires_at=get_token_expires_at(), + created_at=datetime.now(timezone.utc) + ) + + db.add(refresh_token) + await db.commit() + + # 生成访问令牌 + access_token = create_access_token(data={"sub": user.id}) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token_str, + token_type="bearer" + ) + + +@router.post("/register/sms", response_model=TokenResponse, status_code=status.HTTP_201_CREATED) +async def register_with_sms( + request: SmsRegisterRequest, + db: AsyncSession = Depends(get_async_db) +): + """ + 验证码注册 + + 使用验证码验证后完成注册 + """ + # 验证验证码 + success, message = await verify_code( + db=db, + phone=request.phone, + code=request.code, + purpose="register" + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=message + ) + + # 检查手机号是否已存在(双重验证) + stmt = select(User).where(User.phone == request.phone) + result = await db.execute(stmt) + existing_user = result.scalar_one_or_none() + + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="该手机号已被注册" + ) + + # 检查邮箱是否已存在(如果提供了邮箱) + if request.email: + stmt = select(User).where(User.email == request.email) + result = await db.execute(stmt) + existing_email = result.scalar_one_or_none() + + if existing_email: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="该邮箱已被注册" + ) + + # 创建新用户 + user_id = str(uuid.uuid4()) + password_hash = hash_password(request.password) + + user = User( + id=user_id, + phone=request.phone, + password_hash=password_hash, + email=request.email, + nickname=request.nickname, + subscription_type="free", + created_at=datetime.now(timezone.utc) + ) + + db.add(user) + + # 创建刷新令牌 + refresh_token_str = create_refresh_token() + refresh_token = RefreshToken( + id=str(uuid.uuid4()), + user_id=user_id, + token=refresh_token_str, + expires_at=get_token_expires_at(), + created_at=datetime.now(timezone.utc) + ) + + db.add(refresh_token) + await db.commit() + await db.refresh(user) + + # 生成访问令牌 + access_token = create_access_token(data={"sub": user_id}) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token_str, + token_type="bearer" + ) + + +@router.post("/password/reset") +async def reset_password( + request: ResetPasswordRequest, + db: AsyncSession = Depends(get_async_db) +): + """ + 重置密码(忘记密码) + + 使用验证码重置密码 + """ + # 验证验证码 + success, message = await verify_code( + db=db, + phone=request.phone, + code=request.code, + purpose="reset_password" + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=message + ) + + # 查找用户 + stmt = select(User).where(User.phone == request.phone) + result = await db.execute(stmt) + user = result.scalar_one_or_none() + + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在" + ) + + # 更新密码 + user.password_hash = hash_password(request.new_password) + await db.commit() + + return {"message": "密码重置成功"} + + +@router.post("/password/change") +async def change_password( + request: ChangePasswordRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_async_db) +): + """ + 修改密码(已登录) + + 需要验证旧密码 + """ + # 验证旧密码 + if not verify_password(request.old_password, current_user.password_hash): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="旧密码错误" + ) + + # 更新密码 + current_user.password_hash = hash_password(request.new_password) + await db.commit() + + return {"message": "密码修改成功"} + + +@router.post("/phone/change", response_model=UserResponse) +async def change_phone( + request: ChangePhoneRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_async_db) +): + """ + 修改手机号 + + 需要验证新手机号的验证码 + """ + # 验证验证码 + success, message = await verify_code( + db=db, + phone=request.new_phone, + code=request.code, + purpose="change_phone" + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=message + ) + + # 检查新手机号是否已被使用 + stmt = select(User).where(User.phone == request.new_phone) + result = await db.execute(stmt) + existing_user = result.scalar_one_or_none() + + if existing_user and existing_user.id != current_user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="该手机号已被其他用户使用" + ) + + # 更新手机号 + current_user.phone = request.new_phone + await db.commit() + await db.refresh(current_user) + + return UserResponse( + id=current_user.id, + phone=current_user.phone, + email=current_user.email, + nickname=current_user.nickname, + avatar_url=current_user.avatar_url, + subscription_type=current_user.subscription_type, + created_at=current_user.created_at.isoformat() + ) + + +@router.post("/logout/all") +async def logout_all_devices( + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_async_db) +): + """ + 登出所有设备 + + 撤销当前用户的所有刷新令牌 + """ + # 查找用户的所有刷新令牌 + stmt = select(RefreshToken).where( + RefreshToken.user_id == current_user.id, + RefreshToken.is_revoked == False + ) + result = await db.execute(stmt) + tokens = result.scalars().all() + + # 撤销所有令牌 + for token in tokens: + token.is_revoked = True + + await db.commit() + + return {"message": f"已登出所有设备,共撤销 {len(tokens)} 个令牌"}