feat: 添加用户认证功能

- 添加认证路由(注册、登录、刷新令牌、登出、获取用户信息)
- 添加认证服务(密码哈希、JWT令牌生成和验证)
- 添加认证中间件(获取当前用户)
- 支持手机号和密码登录
- 支持访问令牌和刷新令牌机制
This commit is contained in:
徐在坤
2026-01-18 15:57:40 +08:00
parent bf9f3cf363
commit 347fd43b35
9 changed files with 492 additions and 0 deletions

84
api/middleware/auth.py Normal file
View File

@@ -0,0 +1,84 @@
"""
认证依赖从JWT令牌获取当前用户
"""
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_async_db
from database.models import User
from services.auth_service import verify_token
# OAuth2密码流配置
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_async_db)
) -> User:
"""
从JWT令牌获取当前用户
Args:
token: JWT访问令牌
db: 数据库会话
Returns:
当前用户对象
Raises:
HTTPException: 如果令牌无效或用户不存在
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
# 验证令牌
payload = verify_token(token)
if payload is None:
raise credentials_exception
# 获取用户ID
user_id: str = payload.get("sub")
if user_id is None:
raise credentials_exception
# 检查令牌类型
token_type = payload.get("type")
if token_type != "access":
raise credentials_exception
# 从数据库获取用户
user = await db.get(User, user_id)
if user is None:
raise credentials_exception
return user
async def get_optional_user(
token: Optional[str] = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_async_db)
) -> Optional[User]:
"""
可选用户(用于某些公开端点)
Args:
token: JWT访问令牌可选
db: 数据库会话
Returns:
用户对象如果提供了有效令牌否则返回None
"""
if token is None:
return None
try:
return await get_current_user(token, db)
except HTTPException:
return None

298
api/routers/auth.py Normal file
View File

@@ -0,0 +1,298 @@
"""
认证相关 API 路由:注册、登录、刷新令牌、登出
"""
import uuid
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from database import get_async_db
from database.models import User, RefreshToken
from services.auth_service import (
hash_password,
verify_password,
create_access_token,
create_refresh_token,
get_token_expires_at
)
from middleware.auth import get_current_user
router = APIRouter(prefix="/api/auth", tags=["auth"])
# 请求模型
class RegisterRequest(BaseModel):
"""用户注册请求"""
phone: str = Field(..., min_length=11, max_length=11, description="手机号11位")
password: str = Field(..., min_length=6, description="密码至少6位")
nickname: str = Field(..., min_length=1, max_length=50, description="昵称")
email: Optional[str] = Field(None, description="邮箱(可选)")
class LoginRequest(BaseModel):
"""用户登录请求"""
phone: str = Field(..., min_length=11, max_length=11, description="手机号11位")
password: str = Field(..., min_length=1, description="密码")
class RefreshTokenRequest(BaseModel):
"""刷新令牌请求"""
refresh_token: str = Field(..., description="刷新令牌")
# 响应模型
class TokenResponse(BaseModel):
"""令牌响应"""
access_token: str
refresh_token: str
token_type: str = "bearer"
class UserResponse(BaseModel):
"""用户信息响应"""
id: str
phone: str
email: Optional[str]
nickname: str
avatar_url: Optional[str]
subscription_type: str
created_at: str
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
async def register(
request: RegisterRequest,
db: AsyncSession = Depends(get_async_db)
):
"""
用户注册
注册成功后返回访问令牌和刷新令牌
"""
# 检查手机号是否已存在
stmt = select(User).where(User.phone == request.phone)
result = await db.execute(stmt)
existing_user = result.scalar_one_or_none()
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该手机号已被注册"
)
# 检查邮箱是否已存在(如果提供了邮箱)
if request.email:
stmt = select(User).where(User.email == request.email)
result = await db.execute(stmt)
existing_email = result.scalar_one_or_none()
if existing_email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该邮箱已被注册"
)
# 创建新用户
user_id = str(uuid.uuid4())
password_hash = hash_password(request.password)
user = User(
id=user_id,
phone=request.phone,
password_hash=password_hash,
email=request.email,
nickname=request.nickname,
subscription_type="free",
created_at=datetime.now(timezone.utc)
)
db.add(user)
# 创建刷新令牌
refresh_token_str = create_refresh_token()
refresh_token = RefreshToken(
id=str(uuid.uuid4()),
user_id=user_id,
token=refresh_token_str,
expires_at=get_token_expires_at(),
created_at=datetime.now(timezone.utc)
)
db.add(refresh_token)
await db.commit()
await db.refresh(user)
# 生成访问令牌
access_token = create_access_token(data={"sub": user_id})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer"
)
@router.post("/login", response_model=TokenResponse)
async def login(
request: LoginRequest,
db: AsyncSession = Depends(get_async_db)
):
"""
用户登录
验证手机号和密码,返回访问令牌和刷新令牌
"""
# 验证手机号格式(简单验证)
if not request.phone or len(request.phone) != 11 or not request.phone.isdigit():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="手机号格式不正确应为11位数字"
)
# 验证密码不为空
if not request.password:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="密码不能为空"
)
# 查找用户
stmt = select(User).where(User.phone == request.phone)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="手机号或密码错误"
)
# 验证密码
if not verify_password(request.password, user.password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="手机号或密码错误"
)
# 创建新的刷新令牌
refresh_token_str = create_refresh_token()
refresh_token = RefreshToken(
id=str(uuid.uuid4()),
user_id=user.id,
token=refresh_token_str,
expires_at=get_token_expires_at(),
created_at=datetime.now(timezone.utc)
)
db.add(refresh_token)
await db.commit()
# 生成访问令牌
access_token = create_access_token(data={"sub": user.id})
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token_str,
token_type="bearer"
)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(
request: RefreshTokenRequest,
db: AsyncSession = Depends(get_async_db)
):
"""
刷新访问令牌
使用刷新令牌获取新的访问令牌
"""
# 查找刷新令牌
stmt = select(RefreshToken).where(RefreshToken.token == request.refresh_token)
result = await db.execute(stmt)
refresh_token = result.scalar_one_or_none()
if not refresh_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的刷新令牌"
)
# 检查是否已撤销
if refresh_token.is_revoked:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="刷新令牌已撤销"
)
# 检查是否过期
if refresh_token.expires_at < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="刷新令牌已过期"
)
# 验证用户是否存在
user = await db.get(User, refresh_token.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
# 生成新的访问令牌
access_token = create_access_token(data={"sub": user.id})
return TokenResponse(
access_token=access_token,
refresh_token=request.refresh_token, # 返回相同的刷新令牌
token_type="bearer"
)
@router.post("/logout", status_code=status.HTTP_200_OK)
async def logout(
request: RefreshTokenRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
):
"""
用户登出
撤销刷新令牌
"""
# 查找刷新令牌
stmt = select(RefreshToken).where(
RefreshToken.token == request.refresh_token,
RefreshToken.user_id == current_user.id
)
result = await db.execute(stmt)
refresh_token = result.scalar_one_or_none()
if refresh_token:
refresh_token.is_revoked = True
await db.commit()
return {"message": "登出成功"}
@router.get("/me", response_model=UserResponse)
async def get_me(
current_user: User = Depends(get_current_user)
):
"""
获取当前用户信息
"""
return UserResponse(
id=current_user.id,
phone=current_user.phone,
email=current_user.email,
nickname=current_user.nickname,
avatar_url=current_user.avatar_url,
subscription_type=current_user.subscription_type,
created_at=current_user.created_at.isoformat()
)

View File

@@ -0,0 +1,110 @@
"""
认证服务模块密码哈希、JWT令牌生成和验证
"""
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict
from jose import JWTError, jwt
from passlib.context import CryptContext
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT配置
SECRET_KEY = os.getenv("SECRET_KEY", secrets.token_urlsafe(32))
ALGORITHM = os.getenv("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "120")) # 2小时
REFRESH_TOKEN_EXPIRE_DAYS = 30 # 30天
def hash_password(password: str) -> str:
"""
对密码进行哈希加密
Args:
password: 明文密码
Returns:
哈希后的密码
"""
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
验证密码
Args:
plain_password: 明文密码
hashed_password: 哈希后的密码
Returns:
是否匹配
"""
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str:
"""
创建访问令牌JWT
Args:
data: 要编码到令牌中的数据通常包含user_id
expires_delta: 过期时间增量如果不提供则使用默认值2小时
Returns:
JWT令牌字符串
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({
"exp": expire,
"type": "access"
})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def create_refresh_token() -> str:
"""
生成刷新令牌(随机字符串)
Returns:
随机生成的刷新令牌字符串
"""
return secrets.token_urlsafe(32)
def verify_token(token: str) -> Optional[Dict]:
"""
验证JWT令牌
Args:
token: JWT令牌字符串
Returns:
解码后的令牌数据如果无效则返回None
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except JWTError:
return None
def get_token_expires_at() -> datetime:
"""
获取刷新令牌的过期时间30天后
Returns:
过期时间
"""
return datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

Binary file not shown.

After

Width:  |  Height:  |  Size: 492 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB