"""Authentication service: SMS verification codes, password/SMS login and
registration, refresh tokens, and basic profile updates."""

import math
import random
import secrets
import uuid
from datetime import datetime, timedelta, timezone

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.db import utc_now
from app.core.security import (
    create_access_token,
    create_refresh_token as generate_refresh_token_str,
    get_token_expires_at,
    hash_password,
    verify_password,
)
from app.features.auth import repo
from app.features.auth.models import RefreshToken, SmsVerificationCode
from app.features.user.models import User
from app.ports.sms import SmsSender

CODE_LENGTH = 6  # number of digits in an SMS verification code
CODE_EXPIRE_MINUTES = 5  # lifetime of a verification code
RATE_LIMIT_SECONDS = 60  # minimum gap between two codes sent to the same phone

# Message returned by _verify_sms_code when no unused code exists for
# (phone, purpose); login_with_sms uses it to prefer a more specific failure.
_MSG_CODE_NOT_FOUND = "验证码不存在或已使用"


class AuthError(Exception):
    """Domain error raised by AuthService.

    Attributes:
        message: User-facing error message.
        code: Machine-readable error code, e.g. ``"PHONE_EXISTS"``.
    """

    def __init__(self, message: str, code: str = "AUTH_ERROR"):
        self.message = message
        self.code = code
        super().__init__(message)


class AuthService:
    """Authentication use cases backed by an async DB session and an SMS sender.

    Most public methods commit the session themselves before returning.
    """

    def __init__(self, db: AsyncSession, sms: SmsSender):
        self._db = db
        self._sms = sms

    # ── private helpers ──────────────────────────────────────

    def _generate_code(self) -> str:
        """Return a random numeric code of ``CODE_LENGTH`` digits.

        Uses ``secrets`` (a CSPRNG) rather than ``random``: the code is an
        authentication secret and must not be predictable.
        """
        return "".join(str(secrets.randbelow(10)) for _ in range(CODE_LENGTH))

    async def _verify_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> tuple[bool, str]:
        """Check ``code`` against the latest unused code for (phone, purpose).

        On success the record is marked used and committed. An expired
        record is flagged ``is_expired`` and committed.

        Returns:
            ``(success, message)`` where ``message`` is user-facing.
        """
        record = await repo.get_latest_unused_code(phone, purpose, self._db)
        if not record:
            return False, _MSG_CODE_NOT_FOUND

        now = utc_now()
        if now > record.expires_at:
            record.is_expired = True
            await self._db.commit()
            return False, "验证码已过期"

        # Constant-time comparison so response timing does not reveal how
        # much of the code matched. Compared as UTF-8 bytes because
        # compare_digest rejects non-ASCII str input.
        # NOTE(review): failed attempts are not counted, so a code can be
        # guessed repeatedly during its lifetime; consider an attempt limit.
        if not secrets.compare_digest(
            record.code.encode("utf-8"), code.encode("utf-8")
        ):
            return False, "验证码错误"

        record.is_used = True
        record.verified_at = now
        await self._db.commit()
        return True, "验证成功"

    async def _issue_tokens(self, user_id: str, device_info: str = "") -> dict:
        """Create an access/refresh token pair for ``user_id``.

        The refresh-token row is added through the repo but NOT committed;
        the caller is responsible for committing.

        Returns:
            ``{"access_token": ..., "refresh_token": ...}``.
        """
        refresh_str = generate_refresh_token_str()
        token = RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user_id,
            token=refresh_str,
            expires_at=get_token_expires_at(),
            created_at=datetime.now(timezone.utc),
            device_info=device_info or None,
        )
        await repo.create_refresh_token(token, self._db)
        access = create_access_token(data={"sub": user_id})
        return {"access_token": access, "refresh_token": refresh_str}

    async def _require_user(self, user_id: str) -> User:
        """Load a user by id, raising ``AuthError("USER_NOT_FOUND")`` if absent."""
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        return user

    async def _ensure_phone_and_email_free(
        self, phone: str, email: str | None
    ) -> None:
        """Raise AuthError if ``phone`` or a non-empty ``email`` is already taken.

        Raises:
            AuthError: ``PHONE_EXISTS`` or ``EMAIL_EXISTS``.
        """
        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")

    async def _create_user(
        self,
        phone: str,
        password: str,
        nickname: str,
        email: str | None = None,
    ) -> User:
        """Build a free-tier user with a hashed password and add it via the repo.

        Does not commit; the caller commits.
        """
        user = User(
            id=str(uuid.uuid4()),
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
        )
        await repo.create_user(user, self._db)
        return user

    # ── public API ───────────────────────────────────────────

    async def send_sms_code(
        self,
        phone: str,
        purpose: str,
        ip_address: str | None = None,
    ) -> tuple[bool, str, int]:
        """Send an SMS verification code to ``phone`` for ``purpose``.

        Raises:
            AuthError: ``PHONE_EXISTS`` when purpose is "register" and the
                phone is already registered; ``USER_NOT_FOUND`` when purpose
                is "reset_password" and it is not.

        Returns:
            ``(success, message, expires_in_seconds)``; ``expires_in_seconds``
            is 0 when rate-limited or when SMS delivery fails.
        """
        if purpose == "register":
            if await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号已被注册", "PHONE_EXISTS")

        if purpose == "reset_password":
            if not await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号未注册", "USER_NOT_FOUND")

        recent = await repo.get_recent_code_for_rate_limit(phone, self._db)
        if recent:
            elapsed = (utc_now() - recent.created_at).total_seconds()
            if elapsed < RATE_LIMIT_SECONDS:
                # Round up so the user is never told to wait "0" seconds.
                remaining = math.ceil(RATE_LIMIT_SECONDS - elapsed)
                return False, f"发送过于频繁,请{remaining}秒后再试", 0

        code = self._generate_code()
        # The SMS is sent before the record is stored, so a failed delivery
        # leaves no stored code.
        # NOTE(review): send_verification_code is called without await; if it
        # performs network I/O it blocks the event loop.
        if not self._sms.send_verification_code(phone, code):
            return False, "短信发送失败,请稍后重试", 0

        record = SmsVerificationCode(
            id=str(uuid.uuid4()),
            phone=phone,
            code=code,
            purpose=purpose,
            expires_at=utc_now() + timedelta(minutes=CODE_EXPIRE_MINUTES),
            ip_address=ip_address,
        )
        await repo.create_verification_code(record, self._db)
        await self._db.commit()
        return True, "验证码已发送", CODE_EXPIRE_MINUTES * 60

    async def register(
        self,
        phone: str,
        password: str,
        nickname: str,
        email: str | None = None,
    ) -> dict:
        """Register a new user with phone and password (no SMS check).

        Raises:
            AuthError: ``PHONE_EXISTS`` or ``EMAIL_EXISTS``.

        Returns:
            ``{"user": User, "access_token": ..., "refresh_token": ...}``.
        """
        await self._ensure_phone_and_email_free(phone, email)
        user = await self._create_user(phone, password, nickname, email)
        tokens = await self._issue_tokens(user.id)
        await self._db.commit()
        await self._db.refresh(user)
        return {"user": user, **tokens}

    async def login(
        self,
        phone: str,
        password: str,
        device_info: str = "",
    ) -> dict:
        """Log in with phone and password.

        Raises:
            AuthError: ``INVALID_CREDENTIALS`` for an unknown phone or a wrong
                password (same message for both).

        Returns:
            ``{"user": User, "access_token": ..., "refresh_token": ...}``.
        """
        user = await repo.get_user_by_phone(phone, self._db)
        if not user or not verify_password(password, user.password_hash):
            raise AuthError("手机号或密码错误", "INVALID_CREDENTIALS")

        tokens = await self._issue_tokens(user.id, device_info)
        await self._db.commit()
        return {"user": user, **tokens}

    async def refresh_tokens(
        self,
        refresh_token: str,
        device_info: str = "",
    ) -> dict:
        """Issue a new access token for a valid refresh token.

        The refresh token is returned unchanged (it is not rotated), and
        ``device_info`` is currently unused. Nothing is committed.

        Raises:
            AuthError: ``INVALID_TOKEN`` if unknown, ``TOKEN_REVOKED`` if
                revoked, ``TOKEN_EXPIRED`` if past ``expires_at``, or
                ``USER_NOT_FOUND`` if its user no longer exists.

        Returns:
            ``{"access_token": ..., "refresh_token": ...}``.
        """
        token_record = await repo.get_refresh_token_by_token(
            refresh_token, self._db
        )
        if not token_record:
            raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
        if token_record.is_revoked:
            raise AuthError("刷新令牌已撤销", "TOKEN_REVOKED")
        # NOTE(review): compares against an aware UTC datetime; assumes
        # expires_at is loaded timezone-aware (naive vs aware raises TypeError).
        if token_record.expires_at < datetime.now(timezone.utc):
            raise AuthError("刷新令牌已过期", "TOKEN_EXPIRED")

        user = await self._require_user(token_record.user_id)
        return {
            "access_token": create_access_token(data={"sub": user.id}),
            "refresh_token": refresh_token,
        }

    async def logout(self, refresh_token: str, user_id: str) -> None:
        """Revoke ``refresh_token`` if it exists and belongs to ``user_id``."""
        token_record = await repo.get_refresh_token_by_token(
            refresh_token, self._db
        )
        if token_record and token_record.user_id == user_id:
            token_record.is_revoked = True
        await self._db.commit()

    async def logout_all(self, user_id: str) -> int:
        """Revoke all active refresh tokens of ``user_id``; return how many."""
        tokens = await repo.get_active_tokens_for_user(user_id, self._db)
        for token in tokens:
            token.is_revoked = True
        await self._db.commit()
        return len(tokens)

    async def login_with_sms(
        self,
        phone: str,
        code: str,
        device_info: str = "",
        nickname: str | None = None,
    ) -> dict:
        """Log in with an SMS code, auto-registering unknown phones.

        A code issued for either the "login" or the "register" purpose is
        accepted.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` if neither purpose verifies.

        Returns:
            ``{"user": User, "is_new_user": bool, "access_token": ...,
            "refresh_token": ...}``.
        """
        failure_message = _MSG_CODE_NOT_FOUND
        for purpose in ("login", "register"):
            success, message = await self._verify_sms_code(phone, code, purpose)
            if success:
                break
            # Keep the most specific failure: a wrong/expired code for one
            # purpose must not be masked by "not found" for the other.
            if message != _MSG_CODE_NOT_FOUND:
                failure_message = message
        else:
            raise AuthError(failure_message, "INVALID_SMS_CODE")

        user = await repo.get_user_by_phone(phone, self._db)
        is_new_user = user is None
        if is_new_user:
            # SMS-only accounts get a random, undisclosed password.
            user = await self._create_user(
                phone,
                secrets.token_urlsafe(32),
                (nickname or "").strip(),
            )

        tokens = await self._issue_tokens(user.id, device_info)
        await self._db.commit()
        if is_new_user:
            await self._db.refresh(user)
        return {"user": user, "is_new_user": is_new_user, **tokens}

    async def register_with_sms(
        self,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None = None,
        device_info: str = "",
    ) -> dict:
        """Register a new user after verifying a "register" SMS code.

        Raises:
            AuthError: ``PHONE_EXISTS``, ``EMAIL_EXISTS`` or ``INVALID_SMS_CODE``.

        Returns:
            ``{"user": User, "access_token": ..., "refresh_token": ...}``.
        """
        # Check uniqueness first: verifying marks the code used and commits,
        # so checking afterwards would burn a valid code on a doomed request.
        await self._ensure_phone_and_email_free(phone, email)

        success, message = await self._verify_sms_code(phone, code, "register")
        if not success:
            raise AuthError(message, "INVALID_SMS_CODE")

        user = await self._create_user(phone, password, nickname, email)
        tokens = await self._issue_tokens(user.id, device_info)
        await self._db.commit()
        await self._db.refresh(user)
        return {"user": user, **tokens}

    async def reset_password(
        self,
        phone: str,
        code: str,
        new_password: str,
    ) -> None:
        """Reset a password via a "reset_password" SMS code.

        All active refresh tokens of the user are revoked as well.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` or ``USER_NOT_FOUND``.
        """
        success, message = await self._verify_sms_code(
            phone, code, "reset_password"
        )
        if not success:
            raise AuthError(message, "INVALID_SMS_CODE")

        user = await repo.get_user_by_phone(phone, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        user.password_hash = hash_password(new_password)
        # Invalidate existing sessions: whoever held the old credentials must
        # not keep a usable refresh token after an account-recovery reset.
        for token in await repo.get_active_tokens_for_user(user.id, self._db):
            token.is_revoked = True
        await self._db.commit()

    async def change_password(
        self,
        user_id: str,
        old_password: str,
        new_password: str,
    ) -> None:
        """Change a password after checking the old one.

        Raises:
            AuthError: ``USER_NOT_FOUND`` or ``WRONG_PASSWORD``.
        """
        user = await self._require_user(user_id)
        if not verify_password(old_password, user.password_hash):
            raise AuthError("旧密码错误", "WRONG_PASSWORD")

        user.password_hash = hash_password(new_password)
        await self._db.commit()

    async def change_phone(
        self,
        user_id: str,
        new_phone: str,
        code: str,
    ) -> User:
        """Change the phone number after verifying a "change_phone" SMS code.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_TAKEN`` (phone used by a
                different user) or ``USER_NOT_FOUND``.

        Returns:
            The updated, refreshed user.
        """
        success, message = await self._verify_sms_code(
            new_phone, code, "change_phone"
        )
        if not success:
            raise AuthError(message, "INVALID_SMS_CODE")

        existing = await repo.get_user_by_phone(new_phone, self._db)
        if existing and existing.id != user_id:
            raise AuthError("该手机号已被其他用户使用", "PHONE_TAKEN")

        user = await self._require_user(user_id)
        user.phone = new_phone
        await self._db.commit()
        await self._db.refresh(user)
        return user

    async def update_nickname(self, user_id: str, nickname: str) -> User:
        """Set the user's nickname (whitespace-stripped); return the user."""
        user = await self._require_user(user_id)
        user.nickname = nickname.strip()
        await self._db.commit()
        await self._db.refresh(user)
        return user

    async def update_avatar_url(self, user_id: str, avatar_url: str) -> User:
        """Set the user's avatar URL; return the refreshed user."""
        user = await self._require_user(user_id)
        user.avatar_url = avatar_url
        await self._db.commit()
        await self._db.refresh(user)
        return user