refactor: 优化后端认证和路由功能

- 优化auth.py认证路由
- 优化books.py书籍路由
- 优化sms_service.py短信服务
This commit is contained in:
iammm0
2026-01-29 10:57:05 +08:00
parent 6744263773
commit 32fdc066dd
3 changed files with 57 additions and 21 deletions

View File

@@ -6,6 +6,7 @@ import io
import os
import logging
import traceback
import secrets
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
@@ -63,10 +64,11 @@ class SmsRequest(BaseModel):
class SmsLoginRequest(BaseModel):
"""验证码登录请求"""
"""验证码登录/注册请求(统一接口)"""
phone: str = Field(..., min_length=11, max_length=11, description="手机号11位")
code: str = Field(..., min_length=6, max_length=6, description="验证码6位")
agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策")
nickname: Optional[str] = Field(None, max_length=50, description="昵称(注册时必填,登录时可选)")
class SmsRegisterRequest(BaseModel):
@@ -539,8 +541,8 @@ async def send_sms_code(
发送短信验证码
用途:
- register: 注册
- login: 登录
- login: 登录/注册(统一接口,用户不存在时自动注册)
- register: 注册(保留兼容,建议使用login
- reset_password: 重置密码
- change_phone: 修改手机号
"""
@@ -570,8 +572,8 @@ async def send_sms_code(
detail="该手机号已被注册"
)
# 对于登录和重置密码,检查手机号是否存在
if request.purpose in ["login", "reset_password"]:
# 对于重置密码,检查手机号是否存在
if request.purpose == "reset_password":
stmt = select(User).where(User.phone == request.phone)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
@@ -581,6 +583,9 @@ async def send_sms_code(
detail="该手机号未注册"
)
# 对于登录,不再检查用户是否存在,允许自动注册
# 发送验证码时统一使用login目的登录接口会自动处理注册逻辑
# 发送验证码
success, message, expires_in = await send_verification_code(
db=db,
@@ -614,9 +619,10 @@ async def login_with_sms(
db: AsyncSession = Depends(get_async_db)
):
"""
验证码登录
验证码登录/注册(统一接口)
使用手机号和验证码登录
使用手机号和验证码登录,如果用户不存在则自动注册。
注册时需要提供昵称。
"""
# 验证是否同意用户协议和隐私政策
if not request.agreed_to_terms:
@@ -625,13 +631,20 @@ async def login_with_sms(
detail="请先阅读并同意用户协议和隐私政策"
)
# 验证验证码
success, message = await verify_code(
db=db,
phone=request.phone,
code=request.code,
purpose="login"
)
# 验证验证码支持login和register两种目的兼容旧流程
# 先尝试用login验证如果失败再尝试register验证
success = False
message = ""
for purpose in ["login", "register"]:
success, message = await verify_code(
db=db,
phone=request.phone,
code=request.code,
purpose=purpose
)
if success:
break
if not success:
raise HTTPException(
@@ -644,11 +657,34 @@ async def login_with_sms(
result = await db.execute(stmt)
user = result.scalar_one_or_none()
# 如果用户不存在,自动注册
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
# 注册时需要提供昵称
if not request.nickname or len(request.nickname.strip()) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="首次登录需要设置昵称"
)
# 创建新用户
user_id = str(uuid.uuid4())
# 生成一个随机密码(用户不会用到,但数据库要求必填)
random_password = secrets.token_urlsafe(32)
password_hash = hash_password(random_password)
user = User(
id=user_id,
phone=request.phone,
password_hash=password_hash,
email=None,
nickname=request.nickname.strip(),
subscription_type="free",
created_at=datetime.now(timezone.utc)
)
db.add(user)
await db.commit()
await db.refresh(user)
# 创建刷新令牌
refresh_token_str = create_refresh_token()

View File

@@ -21,7 +21,7 @@ async def get_current_book(
db: AsyncSession = Depends(get_async_db)
):
"""获取当前回忆录(需要认证)"""
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()).limit(1)
result = await db.execute(stmt)
book = result.scalar_one_or_none()
@@ -45,7 +45,7 @@ async def clear_book_update(
db: AsyncSession = Depends(get_async_db),
):
"""清除回忆录更新标记"""
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()).limit(1)
result = await db.execute(stmt)
book = result.scalar_one_or_none()
if not book:

View File

@@ -166,7 +166,7 @@ async def check_rate_limit(db: AsyncSession, phone: str) -> Tuple[bool, int]:
SmsVerificationCode.phone == phone
).order_by(
SmsVerificationCode.created_at.desc()
)
).limit(1)
result = await db.execute(stmt)
recent_code = result.scalar_one_or_none()
@@ -262,7 +262,7 @@ async def verify_code(
SmsVerificationCode.is_expired == False
).order_by(
SmsVerificationCode.created_at.desc()
)
).limit(1)
result = await db.execute(stmt)
verification = result.scalar_one_or_none()