381 lines
13 KiB
Python
381 lines
13 KiB
Python
import random
|
|
import secrets
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.db import utc_now
|
|
from app.core.security import (
|
|
create_access_token,
|
|
get_token_expires_at,
|
|
hash_password,
|
|
verify_password,
|
|
)
|
|
from app.core.security import (
|
|
create_refresh_token as generate_refresh_token_str,
|
|
)
|
|
from app.features.auth import repo
|
|
from app.features.auth.models import RefreshToken, SmsVerificationCode
|
|
from app.features.user.models import User
|
|
from app.ports.sms import SmsSender
|
|
|
|
# SMS verification-code settings used by AuthService.
CODE_LENGTH: int = 6  # number of decimal digits in a generated code
CODE_EXPIRE_MINUTES: int = 5  # how long a freshly issued code stays valid
RATE_LIMIT_SECONDS: int = 60  # minimum gap between two code sends to one phone
|
|
|
|
|
|
class AuthError(Exception):
    """Business-rule failure raised by the auth service.

    Attributes:
        message: human-readable (user-facing) description of the failure.
        code: machine-readable error identifier; defaults to "AUTH_ERROR".
    """

    def __init__(self, message: str, code: str = "AUTH_ERROR"):
        # Initialise the base Exception first so str(err) / err.args carry
        # the message, then attach the structured fields.
        super().__init__(message)
        self.code = code
        self.message = message
|
|
|
|
|
|
class AuthService:
    """Authentication use-cases: SMS codes, registration, login, tokens, profile.

    Persistence goes through ``repo`` with the injected ``AsyncSession``; each
    public method commits its own changes. SMS delivery is delegated to the
    injected ``SmsSender``. Failures the caller should surface are raised as
    :class:`AuthError`; ``send_sms_code`` also reports soft failures (rate
    limit, SMS delivery) through its return value.
    """

    # Message returned by ``_verify_sms_code`` when no unused code exists for
    # the given phone/purpose. Named so ``login_with_sms`` can tell "nothing
    # to check for this purpose" apart from a specific failure (wrong/expired).
    _MSG_CODE_NOT_FOUND = "验证码不存在或已使用"

    def __init__(self, db: AsyncSession, sms: SmsSender):
        """Bind the service to a database session and an SMS sender."""
        self._db = db
        self._sms = sms

    # ── private helpers ──────────────────────────────────────

    def _generate_code(self) -> str:
        """Return a uniformly random code of ``CODE_LENGTH`` decimal digits.

        Uses the ``secrets`` CSPRNG: the code is an authentication secret, and
        ``random`` (Mersenne Twister) output is predictable. Zero-padding keeps
        the length fixed (e.g. ``"004217"``).
        """
        return f"{secrets.randbelow(10 ** CODE_LENGTH):0{CODE_LENGTH}d}"

    async def _verify_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> tuple[bool, str]:
        """Check ``code`` against the latest unused code for (phone, purpose).

        On success the record is marked used and ``verified_at`` is set; on
        expiry it is flagged ``is_expired``. Both cases commit. A wrong code
        leaves the record untouched.

        Returns:
            ``(success, message)`` where ``message`` is user-facing text.
        """
        record = await repo.get_latest_unused_code(phone, purpose, self._db)
        if not record:
            return False, self._MSG_CODE_NOT_FOUND

        now = utc_now()
        if now > record.expires_at:
            record.is_expired = True
            await self._db.commit()
            return False, "验证码已过期"

        # Constant-time comparison of the client-supplied code. compare_digest
        # raises TypeError on non-ASCII str, so such input is simply rejected.
        # NOTE(review): failed attempts are not counted here, so nothing in
        # this method limits guessing during the code's lifetime — confirm the
        # verifying endpoints are rate-limited elsewhere.
        if not (code.isascii() and secrets.compare_digest(record.code, code)):
            return False, "验证码错误"

        record.is_used = True
        record.verified_at = now
        await self._db.commit()
        return True, "验证成功"

    async def send_sms_code(
        self,
        phone: str,
        purpose: str,
        ip_address: str | None = None,
    ) -> tuple[bool, str, int]:
        """Generate, send and persist an SMS verification code.

        For ``purpose == "register"`` the phone must not be registered yet;
        for ``purpose == "reset_password"`` it must be. Sends to one phone are
        throttled to one per ``RATE_LIMIT_SECONDS``.

        Returns:
            ``(success, message, expires_in_seconds)``; ``expires_in_seconds``
            is 0 when nothing was sent.

        Raises:
            AuthError: ``PHONE_EXISTS`` / ``USER_NOT_FOUND`` from the purpose
                checks above.
        """
        import math

        if purpose == "register":
            if await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if purpose == "reset_password":
            if not await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号未注册", "USER_NOT_FOUND")

        recent = await repo.get_recent_code_for_rate_limit(phone, self._db)
        if recent:
            elapsed = (utc_now() - recent.created_at).total_seconds()
            if elapsed < RATE_LIMIT_SECONDS:
                # Round up: truncation could tell the user to retry in 0 s.
                remaining = math.ceil(RATE_LIMIT_SECONDS - elapsed)
                return False, f"发送过于频繁,请{remaining}秒后再试", 0

        code = self._generate_code()
        # Send before persisting: if delivery fails no record is stored.
        if not self._sms.send_verification_code(phone, code):
            return False, "短信发送失败,请稍后重试", 0

        expires_at = utc_now() + timedelta(minutes=CODE_EXPIRE_MINUTES)
        record = SmsVerificationCode(
            id=str(uuid.uuid4()),
            phone=phone,
            code=code,
            purpose=purpose,
            expires_at=expires_at,
            ip_address=ip_address,
        )
        await repo.create_verification_code(record, self._db)
        await self._db.commit()
        return True, "验证码已发送", CODE_EXPIRE_MINUTES * 60

    async def _issue_tokens(self, user_id: str, device_info: str = "") -> dict:
        """Create an access/refresh token pair for ``user_id``.

        The refresh token record is handed to ``repo.create_refresh_token``
        but NOT committed here — the caller commits.

        Returns:
            ``{"access_token": str, "refresh_token": str}``.
        """
        refresh_str = generate_refresh_token_str()
        token = RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user_id,
            token=refresh_str,
            expires_at=get_token_expires_at(),
            created_at=datetime.now(timezone.utc),
            device_info=device_info or None,
        )
        await repo.create_refresh_token(token, self._db)
        access = create_access_token(data={"sub": user_id})
        return {"access_token": access, "refresh_token": refresh_str}

    # ── public API ───────────────────────────────────────────

    async def register(
        self,
        phone: str,
        password: str,
        nickname: str,
        email: str | None = None,
    ) -> dict:
        """Register a user with phone + password (no SMS verification here).

        Returns:
            ``{"user": User, "access_token": str, "refresh_token": str}``.

        Raises:
            AuthError: ``PHONE_EXISTS`` or ``EMAIL_EXISTS``.
        """
        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")

        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
        )
        await repo.create_user(user, self._db)
        tokens = await self._issue_tokens(user_id)
        await self._db.commit()
        await self._db.refresh(user)
        return {"user": user, **tokens}

    async def login(
        self,
        phone: str,
        password: str,
        device_info: str = "",
    ) -> dict:
        """Log in with phone + password.

        Returns:
            ``{"user": User, "access_token": str, "refresh_token": str}``.

        Raises:
            AuthError: ``INVALID_CREDENTIALS`` for an unknown phone or a wrong
                password (same error for both, so callers cannot tell which).
        """
        user = await repo.get_user_by_phone(phone, self._db)
        if not user or not verify_password(password, user.password_hash):
            raise AuthError("手机号或密码错误", "INVALID_CREDENTIALS")

        tokens = await self._issue_tokens(user.id, device_info)
        await self._db.commit()
        return {"user": user, **tokens}

    async def refresh_tokens(
        self,
        refresh_token: str,
        device_info: str = "",
    ) -> dict:
        """Issue a new access token for a valid refresh token.

        Returns:
            ``{"access_token": str, "refresh_token": str}`` — the refresh
            token returned is the one passed in.

        Raises:
            AuthError: ``INVALID_TOKEN``, ``TOKEN_REVOKED``, ``TOKEN_EXPIRED``
                or ``USER_NOT_FOUND``.
        """
        # NOTE(review): ``device_info`` is accepted but unused, and the same
        # refresh token is handed back (no rotation). If rotation is intended,
        # revoke ``token_record`` and call ``self._issue_tokens`` instead —
        # clients would then have to store the returned refresh token.
        token_record = await repo.get_refresh_token_by_token(refresh_token, self._db)
        if not token_record:
            raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
        if token_record.is_revoked:
            raise AuthError("刷新令牌已撤销", "TOKEN_REVOKED")
        if token_record.expires_at < datetime.now(timezone.utc):
            raise AuthError("刷新令牌已过期", "TOKEN_EXPIRED")

        user = await repo.get_user_by_id(token_record.user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        access_token = create_access_token(data={"sub": user.id})
        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
        }

    async def logout(self, refresh_token: str, user_id: str) -> None:
        """Revoke ``refresh_token`` if it belongs to ``user_id``.

        Unknown tokens and tokens owned by another user are ignored silently.
        """
        token_record = await repo.get_refresh_token_by_token(refresh_token, self._db)
        if token_record and token_record.user_id == user_id:
            token_record.is_revoked = True
        await self._db.commit()

    async def logout_all(self, user_id: str) -> int:
        """Revoke all active refresh tokens of ``user_id``; return how many."""
        tokens = await repo.get_active_tokens_for_user(user_id, self._db)
        for token in tokens:
            token.is_revoked = True
        await self._db.commit()
        return len(tokens)

    async def login_with_sms(
        self,
        phone: str,
        code: str,
        device_info: str = "",
        nickname: str | None = None,
    ) -> dict:
        """Log in with an SMS code, auto-registering unknown phones.

        A code issued for either the "login" or the "register" purpose is
        accepted.

        Returns:
            ``{"user", "is_new_user", "access_token", "refresh_token"}``.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` when neither purpose verifies.
        """
        success = False
        failure_message = self._MSG_CODE_NOT_FOUND
        for purpose in ("login", "register"):
            success, message = await self._verify_sms_code(phone, code, purpose)
            if success:
                break
            # Keep the most specific failure: a wrong/expired "login" code
            # must not be masked by "no code exists" from the "register" try.
            if failure_message == self._MSG_CODE_NOT_FOUND:
                failure_message = message

        if not success:
            raise AuthError(failure_message, "INVALID_SMS_CODE")

        return await self._sms_login_after_code_verified(
            phone, device_info=device_info, nickname=nickname
        )

    async def _sms_login_after_code_verified(
        self,
        phone: str,
        *,
        device_info: str = "",
        nickname: str | None = None,
    ) -> dict:
        """After SMS verification has passed: find or create the user, issue tokens.

        New users get a hash of a random password nobody knows, so password
        login is unusable for them until they set one.
        """
        user = await repo.get_user_by_phone(phone, self._db)
        is_new_user = user is None

        if is_new_user:
            user_id = str(uuid.uuid4())
            user = User(
                id=user_id,
                phone=phone,
                password_hash=hash_password(secrets.token_urlsafe(32)),
                nickname=(nickname or "").strip(),
                subscription_type="free",
                created_at=datetime.now(timezone.utc),
            )
            await repo.create_user(user, self._db)

        tokens = await self._issue_tokens(user.id, device_info)
        await self._db.commit()
        if is_new_user:
            await self._db.refresh(user)

        return {"user": user, "is_new_user": is_new_user, **tokens}

    async def mock_sms_login(
        self,
        phone: str,
        device_info: str = "",
        nickname: str | None = None,
    ) -> dict:
        """Log in / auto-register WITHOUT checking an SMS code.

        Intended only for the mock route when configuration allows it.
        """
        return await self._sms_login_after_code_verified(
            phone, device_info=device_info, nickname=nickname
        )

    async def register_with_sms(
        self,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None = None,
        device_info: str = "",
    ) -> dict:
        """Register with an SMS code issued for the "register" purpose.

        Returns:
            ``{"user": User, "access_token": str, "refresh_token": str}``.

        Raises:
            AuthError: ``PHONE_EXISTS``, ``EMAIL_EXISTS`` or
                ``INVALID_SMS_CODE``.
        """
        # Uniqueness checks run BEFORE the one-time code is consumed, so a
        # duplicate phone or a mistyped email does not burn a valid code.
        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")

        success, message = await self._verify_sms_code(phone, code, "register")
        if not success:
            raise AuthError(message, "INVALID_SMS_CODE")

        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
        )
        await repo.create_user(user, self._db)
        tokens = await self._issue_tokens(user_id, device_info)
        await self._db.commit()
        await self._db.refresh(user)
        return {"user": user, **tokens}

    async def reset_password(
        self,
        phone: str,
        code: str,
        new_password: str,
    ) -> None:
        """Reset the password using an SMS code for "reset_password".

        All active refresh tokens of the user are revoked in the same commit,
        so sessions opened before the reset cannot be renewed (already-issued
        access tokens are not tracked by this service).

        Raises:
            AuthError: ``INVALID_SMS_CODE`` or ``USER_NOT_FOUND``.
        """
        success, message = await self._verify_sms_code(phone, code, "reset_password")
        if not success:
            raise AuthError(message, "INVALID_SMS_CODE")

        user = await repo.get_user_by_phone(phone, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        user.password_hash = hash_password(new_password)
        # A reset may follow an account compromise: cut off existing sessions.
        for token in await repo.get_active_tokens_for_user(user.id, self._db):
            token.is_revoked = True
        await self._db.commit()

    async def change_password(
        self,
        user_id: str,
        old_password: str,
        new_password: str,
    ) -> None:
        """Change the password after verifying the old one.

        Raises:
            AuthError: ``USER_NOT_FOUND`` or ``WRONG_PASSWORD``.
        """
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        if not verify_password(old_password, user.password_hash):
            raise AuthError("旧密码错误", "WRONG_PASSWORD")

        user.password_hash = hash_password(new_password)
        await self._db.commit()

    async def change_phone(
        self,
        user_id: str,
        new_phone: str,
        code: str,
    ) -> User:
        """Move the account to ``new_phone`` using a "change_phone" SMS code.

        Returns:
            The refreshed, updated user.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_TAKEN`` or
                ``USER_NOT_FOUND``.
        """
        success, message = await self._verify_sms_code(new_phone, code, "change_phone")
        if not success:
            raise AuthError(message, "INVALID_SMS_CODE")

        existing = await repo.get_user_by_phone(new_phone, self._db)
        if existing and existing.id != user_id:
            raise AuthError("该手机号已被其他用户使用", "PHONE_TAKEN")

        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        user.phone = new_phone
        await self._db.commit()
        await self._db.refresh(user)
        return user

    async def update_nickname(self, user_id: str, nickname: str) -> User:
        """Set the user's nickname (surrounding whitespace stripped).

        Raises:
            AuthError: ``USER_NOT_FOUND``.
        """
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        user.nickname = nickname.strip()
        await self._db.commit()
        await self._db.refresh(user)
        return user

    async def update_avatar_url(self, user_id: str, avatar_url: str) -> User:
        """Set the user's avatar URL (stored as given).

        Raises:
            AuthError: ``USER_NOT_FOUND``.
        """
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        user.avatar_url = avatar_url
        await self._db.commit()
        await self._db.refresh(user)
        return user
|