refactor: 优化后端认证和路由功能
- 优化auth.py认证路由 - 优化books.py书籍路由 - 优化sms_service.py短信服务
This commit is contained in:
@@ -6,6 +6,7 @@ import io
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
import secrets
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -63,10 +64,11 @@ class SmsRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SmsLoginRequest(BaseModel):
|
class SmsLoginRequest(BaseModel):
|
||||||
"""验证码登录请求"""
|
"""验证码登录/注册请求(统一接口)"""
|
||||||
phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)")
|
phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)")
|
||||||
code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)")
|
code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)")
|
||||||
agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策")
|
agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策")
|
||||||
|
nickname: Optional[str] = Field(None, max_length=50, description="昵称(注册时必填,登录时可选)")
|
||||||
|
|
||||||
|
|
||||||
class SmsRegisterRequest(BaseModel):
|
class SmsRegisterRequest(BaseModel):
|
||||||
@@ -539,8 +541,8 @@ async def send_sms_code(
|
|||||||
发送短信验证码
|
发送短信验证码
|
||||||
|
|
||||||
用途:
|
用途:
|
||||||
- register: 注册
|
- login: 登录/注册(统一接口,用户不存在时自动注册)
|
||||||
- login: 登录
|
- register: 注册(保留兼容,建议使用login)
|
||||||
- reset_password: 重置密码
|
- reset_password: 重置密码
|
||||||
- change_phone: 修改手机号
|
- change_phone: 修改手机号
|
||||||
"""
|
"""
|
||||||
@@ -570,8 +572,8 @@ async def send_sms_code(
|
|||||||
detail="该手机号已被注册"
|
detail="该手机号已被注册"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 对于登录和重置密码,检查手机号是否存在
|
# 对于重置密码,检查手机号是否存在
|
||||||
if request.purpose in ["login", "reset_password"]:
|
if request.purpose == "reset_password":
|
||||||
stmt = select(User).where(User.phone == request.phone)
|
stmt = select(User).where(User.phone == request.phone)
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
@@ -581,6 +583,9 @@ async def send_sms_code(
|
|||||||
detail="该手机号未注册"
|
detail="该手机号未注册"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 对于登录,不再检查用户是否存在,允许自动注册
|
||||||
|
# 发送验证码时统一使用login目的,登录接口会自动处理注册逻辑
|
||||||
|
|
||||||
# 发送验证码
|
# 发送验证码
|
||||||
success, message, expires_in = await send_verification_code(
|
success, message, expires_in = await send_verification_code(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -614,9 +619,10 @@ async def login_with_sms(
|
|||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
验证码登录
|
验证码登录/注册(统一接口)
|
||||||
|
|
||||||
使用手机号和验证码登录
|
使用手机号和验证码登录,如果用户不存在则自动注册。
|
||||||
|
注册时需要提供昵称。
|
||||||
"""
|
"""
|
||||||
# 验证是否同意用户协议和隐私政策
|
# 验证是否同意用户协议和隐私政策
|
||||||
if not request.agreed_to_terms:
|
if not request.agreed_to_terms:
|
||||||
@@ -625,13 +631,20 @@ async def login_with_sms(
|
|||||||
detail="请先阅读并同意用户协议和隐私政策"
|
detail="请先阅读并同意用户协议和隐私政策"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证验证码
|
# 验证验证码(支持login和register两种目的,兼容旧流程)
|
||||||
|
# 先尝试用login验证,如果失败再尝试register验证
|
||||||
|
success = False
|
||||||
|
message = ""
|
||||||
|
|
||||||
|
for purpose in ["login", "register"]:
|
||||||
success, message = await verify_code(
|
success, message = await verify_code(
|
||||||
db=db,
|
db=db,
|
||||||
phone=request.phone,
|
phone=request.phone,
|
||||||
code=request.code,
|
code=request.code,
|
||||||
purpose="login"
|
purpose=purpose
|
||||||
)
|
)
|
||||||
|
if success:
|
||||||
|
break
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -644,12 +657,35 @@ async def login_with_sms(
|
|||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# 如果用户不存在,自动注册
|
||||||
if not user:
|
if not user:
|
||||||
|
# 注册时需要提供昵称
|
||||||
|
if not request.nickname or len(request.nickname.strip()) == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="用户不存在"
|
detail="首次登录需要设置昵称"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建新用户
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
# 生成一个随机密码(用户不会用到,但数据库要求必填)
|
||||||
|
random_password = secrets.token_urlsafe(32)
|
||||||
|
password_hash = hash_password(random_password)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=user_id,
|
||||||
|
phone=request.phone,
|
||||||
|
password_hash=password_hash,
|
||||||
|
email=None,
|
||||||
|
nickname=request.nickname.strip(),
|
||||||
|
subscription_type="free",
|
||||||
|
created_at=datetime.now(timezone.utc)
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(user)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
# 创建刷新令牌
|
# 创建刷新令牌
|
||||||
refresh_token_str = create_refresh_token()
|
refresh_token_str = create_refresh_token()
|
||||||
refresh_token = RefreshToken(
|
refresh_token = RefreshToken(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ async def get_current_book(
|
|||||||
db: AsyncSession = Depends(get_async_db)
|
db: AsyncSession = Depends(get_async_db)
|
||||||
):
|
):
|
||||||
"""获取当前回忆录(需要认证)"""
|
"""获取当前回忆录(需要认证)"""
|
||||||
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
|
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()).limit(1)
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
book = result.scalar_one_or_none()
|
book = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ async def clear_book_update(
|
|||||||
db: AsyncSession = Depends(get_async_db),
|
db: AsyncSession = Depends(get_async_db),
|
||||||
):
|
):
|
||||||
"""清除回忆录更新标记"""
|
"""清除回忆录更新标记"""
|
||||||
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
|
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()).limit(1)
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
book = result.scalar_one_or_none()
|
book = result.scalar_one_or_none()
|
||||||
if not book:
|
if not book:
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ async def check_rate_limit(db: AsyncSession, phone: str) -> Tuple[bool, int]:
|
|||||||
SmsVerificationCode.phone == phone
|
SmsVerificationCode.phone == phone
|
||||||
).order_by(
|
).order_by(
|
||||||
SmsVerificationCode.created_at.desc()
|
SmsVerificationCode.created_at.desc()
|
||||||
)
|
).limit(1)
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
recent_code = result.scalar_one_or_none()
|
recent_code = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -262,7 +262,7 @@ async def verify_code(
|
|||||||
SmsVerificationCode.is_expired == False
|
SmsVerificationCode.is_expired == False
|
||||||
).order_by(
|
).order_by(
|
||||||
SmsVerificationCode.created_at.desc()
|
SmsVerificationCode.created_at.desc()
|
||||||
)
|
).limit(1)
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
verification = result.scalar_one_or_none()
|
verification = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user