refactor: 优化后端认证和路由功能

- 优化auth.py认证路由
- 优化books.py书籍路由
- 优化sms_service.py短信服务
This commit is contained in:
iammm0
2026-01-29 10:57:05 +08:00
parent 6744263773
commit 32fdc066dd
3 changed files with 57 additions and 21 deletions

View File

@@ -6,6 +6,7 @@ import io
import os import os
import logging import logging
import traceback import traceback
import secrets
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -63,10 +64,11 @@ class SmsRequest(BaseModel):
class SmsLoginRequest(BaseModel): class SmsLoginRequest(BaseModel):
"""验证码登录请求""" """验证码登录/注册请求(统一接口)"""
phone: str = Field(..., min_length=11, max_length=11, description="手机号11位") phone: str = Field(..., min_length=11, max_length=11, description="手机号11位")
code: str = Field(..., min_length=6, max_length=6, description="验证码6位") code: str = Field(..., min_length=6, max_length=6, description="验证码6位")
agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策")
nickname: Optional[str] = Field(None, max_length=50, description="昵称(注册时必填,登录时可选)")
class SmsRegisterRequest(BaseModel): class SmsRegisterRequest(BaseModel):
@@ -539,8 +541,8 @@ async def send_sms_code(
发送短信验证码 发送短信验证码
用途: 用途:
- register: 注册 - login: 登录/注册(统一接口,用户不存在时自动注册)
- login: 登录 - register: 注册(保留兼容,建议使用login
- reset_password: 重置密码 - reset_password: 重置密码
- change_phone: 修改手机号 - change_phone: 修改手机号
""" """
@@ -570,8 +572,8 @@ async def send_sms_code(
detail="该手机号已被注册" detail="该手机号已被注册"
) )
# 对于登录和重置密码,检查手机号是否存在 # 对于重置密码,检查手机号是否存在
if request.purpose in ["login", "reset_password"]: if request.purpose == "reset_password":
stmt = select(User).where(User.phone == request.phone) stmt = select(User).where(User.phone == request.phone)
result = await db.execute(stmt) result = await db.execute(stmt)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
@@ -581,6 +583,9 @@ async def send_sms_code(
detail="该手机号未注册" detail="该手机号未注册"
) )
# 对于登录,不再检查用户是否存在,允许自动注册
# 发送验证码时统一使用login目的登录接口会自动处理注册逻辑
# 发送验证码 # 发送验证码
success, message, expires_in = await send_verification_code( success, message, expires_in = await send_verification_code(
db=db, db=db,
@@ -614,9 +619,10 @@ async def login_with_sms(
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
""" """
验证码登录 验证码登录/注册(统一接口)
使用手机号和验证码登录 使用手机号和验证码登录,如果用户不存在则自动注册。
注册时需要提供昵称。
""" """
# 验证是否同意用户协议和隐私政策 # 验证是否同意用户协议和隐私政策
if not request.agreed_to_terms: if not request.agreed_to_terms:
@@ -625,13 +631,20 @@ async def login_with_sms(
detail="请先阅读并同意用户协议和隐私政策" detail="请先阅读并同意用户协议和隐私政策"
) )
# 验证验证码 # 验证验证码支持login和register两种目的兼容旧流程
# 先尝试用login验证如果失败再尝试register验证
success = False
message = ""
for purpose in ["login", "register"]:
success, message = await verify_code( success, message = await verify_code(
db=db, db=db,
phone=request.phone, phone=request.phone,
code=request.code, code=request.code,
purpose="login" purpose=purpose
) )
if success:
break
if not success: if not success:
raise HTTPException( raise HTTPException(
@@ -644,12 +657,35 @@ async def login_with_sms(
result = await db.execute(stmt) result = await db.execute(stmt)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
# 如果用户不存在,自动注册
if not user: if not user:
# 注册时需要提供昵称
if not request.nickname or len(request.nickname.strip()) == 0:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_400_BAD_REQUEST,
detail="用户不存在" detail="首次登录需要设置昵称"
) )
# 创建新用户
user_id = str(uuid.uuid4())
# 生成一个随机密码(用户不会用到,但数据库要求必填)
random_password = secrets.token_urlsafe(32)
password_hash = hash_password(random_password)
user = User(
id=user_id,
phone=request.phone,
password_hash=password_hash,
email=None,
nickname=request.nickname.strip(),
subscription_type="free",
created_at=datetime.now(timezone.utc)
)
db.add(user)
await db.commit()
await db.refresh(user)
# 创建刷新令牌 # 创建刷新令牌
refresh_token_str = create_refresh_token() refresh_token_str = create_refresh_token()
refresh_token = RefreshToken( refresh_token = RefreshToken(

View File

@@ -21,7 +21,7 @@ async def get_current_book(
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""获取当前回忆录(需要认证)""" """获取当前回忆录(需要认证)"""
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()) stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()).limit(1)
result = await db.execute(stmt) result = await db.execute(stmt)
book = result.scalar_one_or_none() book = result.scalar_one_or_none()
@@ -45,7 +45,7 @@ async def clear_book_update(
db: AsyncSession = Depends(get_async_db), db: AsyncSession = Depends(get_async_db),
): ):
"""清除回忆录更新标记""" """清除回忆录更新标记"""
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()) stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()).limit(1)
result = await db.execute(stmt) result = await db.execute(stmt)
book = result.scalar_one_or_none() book = result.scalar_one_or_none()
if not book: if not book:

View File

@@ -166,7 +166,7 @@ async def check_rate_limit(db: AsyncSession, phone: str) -> Tuple[bool, int]:
SmsVerificationCode.phone == phone SmsVerificationCode.phone == phone
).order_by( ).order_by(
SmsVerificationCode.created_at.desc() SmsVerificationCode.created_at.desc()
) ).limit(1)
result = await db.execute(stmt) result = await db.execute(stmt)
recent_code = result.scalar_one_or_none() recent_code = result.scalar_one_or_none()
@@ -262,7 +262,7 @@ async def verify_code(
SmsVerificationCode.is_expired == False SmsVerificationCode.is_expired == False
).order_by( ).order_by(
SmsVerificationCode.created_at.desc() SmsVerificationCode.created_at.desc()
) ).limit(1)
result = await db.execute(stmt) result = await db.execute(stmt)
verification = result.scalar_one_or_none() verification = result.scalar_one_or_none()