配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM)
812 lines
30 KiB
Python
812 lines
30 KiB
Python
import asyncio
|
|
import io
|
|
import random
|
|
import secrets
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from PIL import Image
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.cos_url_keys import (
|
|
best_effort_delete_cos_object_for_url,
|
|
extract_cos_object_key_if_owned,
|
|
)
|
|
from app.core.db import transactional, transactional_nested, utc_now
|
|
from app.core.errors import (
|
|
AppError,
|
|
BadRequestError,
|
|
ProviderError,
|
|
RateLimitedError,
|
|
ServiceUnavailableError,
|
|
)
|
|
from app.core.logging import get_logger
|
|
from app.core.security import (
|
|
create_access_token,
|
|
get_token_expires_at,
|
|
hash_password,
|
|
verify_password,
|
|
)
|
|
from app.core.security import (
|
|
create_refresh_token as generate_refresh_token_str,
|
|
)
|
|
from app.features.auth import repo
|
|
from app.features.auth.integrity import is_user_phone_unique_violation, user_integrity_auth_code
|
|
from app.features.auth.models import RefreshToken, SmsVerificationCode
|
|
from app.features.user.models import User
|
|
from app.ports.sms import SmsSender
|
|
from app.ports.storage import ObjectStorage
|
|
|
|
logger = get_logger(__name__)

# ── SMS verification-code policy ─────────────────────────────
CODE_LENGTH: int = 6  # number of decimal digits in a generated code
CODE_EXPIRE_MINUTES: int = 5  # validity window, counted from issuance
RATE_LIMIT_SECONDS: int = 60  # minimum gap between two sends to the same phone

# Generic user-facing reason used when no usable code exists / it was consumed.
_SMS_CONSUME_FALLBACK: str = "验证码不存在或已使用"
|
|
|
|
|
|
def _sms_is_configured() -> bool:
    """Report whether every Tencent Cloud SMS setting is present and non-blank.

    A setting that is ``None`` or contains only whitespace counts as missing.
    Settings are read lazily and the check stops at the first missing one.
    """
    required_settings = (
        "tencent_secret_id",
        "tencent_secret_key",
        "tencent_sms_sdk_app_id",
        "tencent_sms_sign_name",
        "tencent_sms_template_id",
    )
    return all(
        (getattr(settings, name) or "").strip() for name in required_settings
    )
|
|
|
|
|
|
_VALID_LANGUAGES = {"zh", "en"}
|
|
|
|
|
|
def _normalize_language(lang: str | None) -> str:
    """Normalize a device language token to one of ``_VALID_LANGUAGES``.

    Exact tokens are accepted case-insensitively and whitespace-trimmed
    ("zh", " EN "). Region/script-qualified tags in BCP 47 or POSIX style
    ("en-US", "en_GB", "zh-Hant-TW") are reduced to their primary language
    subtag. Missing or unsupported values default to "zh".
    """
    if not lang:
        return "zh"
    token = str(lang).strip().lower()
    if token in _VALID_LANGUAGES:
        return token
    # Keep only the primary subtag before the first "-" / "_" separator, so that
    # e.g. "en-US" is treated as English instead of falling back to "zh".
    primary = token.replace("_", "-").split("-", 1)[0]
    return primary if primary in _VALID_LANGUAGES else "zh"
|
|
|
|
|
|
def _as_utc(dt: datetime) -> datetime:
|
|
"""Normalize DB datetimes for safe comparison (sqlite may return naive)."""
|
|
if dt.tzinfo is None:
|
|
return dt.replace(tzinfo=timezone.utc)
|
|
return dt.astimezone(timezone.utc)
|
|
|
|
|
|
_AUTH_CODE_MAP: dict[str, tuple[int, str]] = {
|
|
"INVALID_CREDENTIALS": (401, "AUTHENTICATION_FAILED"),
|
|
"INVALID_TOKEN": (401, "AUTHENTICATION_FAILED"),
|
|
"TOKEN_REVOKED": (401, "AUTHENTICATION_FAILED"),
|
|
"TOKEN_EXPIRED": (401, "AUTHENTICATION_FAILED"),
|
|
"REFRESH_TOKEN_REUSE": (401, "REFRESH_TOKEN_REUSE"),
|
|
"USER_NOT_FOUND": (404, "NOT_FOUND"),
|
|
"PHONE_EXISTS": (400, "PHONE_EXISTS"),
|
|
"EMAIL_EXISTS": (400, "EMAIL_EXISTS"),
|
|
"PHONE_TAKEN": (409, "PHONE_TAKEN"),
|
|
"INVALID_SMS_CODE": (400, "INVALID_SMS_CODE"),
|
|
"WRONG_PASSWORD": (400, "WRONG_PASSWORD"),
|
|
"AUTH_ERROR": (400, "BAD_REQUEST"),
|
|
}
|
|
|
|
|
|
class AuthError(AppError):
    """Authentication/account error carrying a fine-grained internal ``code``.

    The internal code is translated through ``_AUTH_CODE_MAP`` into the HTTP
    status and public ``error_code`` passed to ``AppError``. A code missing
    from the map yields HTTP 400 and is used unchanged as the public code.
    """

    def __init__(self, message: str, code: str = "AUTH_ERROR"):
        mapping = _AUTH_CODE_MAP.get(code)
        if mapping is None:
            mapping = (400, code)
        status_code, error_code = mapping
        super().__init__(message, status_code=status_code, error_code=error_code)
        # Set after AppError.__init__ so the internal code always wins; the
        # public error_code may be coarser (several codes collapse into one).
        self.code = code
|
|
|
|
|
|
def _raise_auth_error_from_user_integrity(
    exc: IntegrityError,
    *,
    phone_conflict: str,
) -> None:
    """Translate a user unique-constraint violation into an ``AuthError``.

    ``user_integrity_auth_code`` classifies *exc*; *phone_conflict* is forwarded
    to it (callers in this module pass "PHONE_EXISTS" or "PHONE_TAKEN").
    Recognized codes raise ``AuthError`` chained to *exc*; any other result
    re-raises *exc* itself. This function never returns normally.
    """
    messages_by_code = {
        "PHONE_EXISTS": "该手机号已被注册",
        "EMAIL_EXISTS": "该邮箱已被注册",
        "PHONE_TAKEN": "该手机号已被其他用户使用",
    }
    auth_code = user_integrity_auth_code(exc, phone_conflict=phone_conflict)
    message = messages_by_code.get(auth_code)
    if message is None:
        raise exc
    raise AuthError(message, auth_code) from exc
|
|
|
|
|
|
async def _create_user_with_integrity_check(
    db: AsyncSession,
    user: User,
    *,
    phone_conflict: str,
) -> None:
    """Stage *user* via ``repo.create_user`` and flush immediately.

    Flushing here makes unique-constraint violations surface inside the
    caller's ``transactional()`` block, where a flush-time ``IntegrityError``
    is translated by ``_raise_auth_error_from_user_integrity`` (AuthError for
    known phone/email conflicts, otherwise the original error re-raised).
    Errors raised by ``repo.create_user`` itself are not translated.
    """
    await repo.create_user(user, db)
    try:
        await db.flush()
    except IntegrityError as integrity_error:
        _raise_auth_error_from_user_integrity(
            integrity_error, phone_conflict=phone_conflict
        )
|
|
|
|
|
|
def _public_tokens(issued: dict) -> dict:
|
|
"""Strip internal fields before returning tokens to callers."""
|
|
return {
|
|
"access_token": issued["access_token"],
|
|
"refresh_token": issued["refresh_token"],
|
|
}
|
|
|
|
|
|
class AuthService:
    """Authentication and account use-cases bound to one ``AsyncSession``.

    Covers SMS verification codes, password and SMS login/registration,
    refresh-token rotation with reuse detection, logout, password/phone
    changes, and profile (nickname/avatar) updates. Writes are wrapped in
    ``transactional(self._db)`` blocks; SMS delivery and object storage are
    injected ports.
    """

    def __init__(
        self,
        db: AsyncSession,
        sms: SmsSender,
        *,
        object_storage: ObjectStorage | None = None,
    ):
        self._db = db
        self._sms = sms
        # Optional: only ``upload_avatar`` uses it, and it raises
        # ServiceUnavailableError when the port is missing.
        self._object_storage = object_storage

    # ── private helpers ──────────────────────────────────────

    def _generate_code(self) -> str:
        """Return a ``CODE_LENGTH``-digit numeric verification code.

        Uses ``secrets`` (a CSPRNG): ``random`` is a Mersenne Twister whose
        future output can be reconstructed from observed values, which is not
        acceptable for a one-time login credential.
        """
        return "".join(str(secrets.randbelow(10)) for _ in range(CODE_LENGTH))

    async def _check_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> tuple[SmsVerificationCode | None, str]:
        """Validate an SMS code without consuming it.

        UX pre-check only; authoritative validation is
        ``repo.try_consume_verification_code`` inside a transaction.

        Returns:
            ``(record, "验证成功")`` when the latest unused code for
            ``(phone, purpose)`` is unexpired and equals *code*; otherwise
            ``(None, <user-facing reason>)``. As a side effect, an expired
            record is flagged ``is_expired``.
        """
        record = await repo.get_latest_unused_code(phone, purpose, self._db)
        if not record:
            return None, _SMS_CONSUME_FALLBACK
        now = utc_now()
        # Normalize like the rest of this module: SQLite may return naive
        # datetimes, and comparing naive with aware values raises TypeError.
        if now > _as_utc(record.expires_at):
            async with transactional(self._db):
                record.is_expired = True
            return None, "验证码已过期"
        if record.code != code:
            return None, "验证码错误"
        return record, "验证成功"

    async def _precheck_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> str | None:
        """UX pre-check: fast-fail without consuming.

        Returns ``None`` when the code currently looks valid, otherwise the
        user-facing reason it is not.
        """
        record, message = await self._check_sms_code(phone, code, purpose)
        if record is None:
            return message
        return None

    async def _precheck_sms_code_for_purposes(
        self, phone: str, code: str, purposes: tuple[str, ...]
    ) -> str | None:
        """Login flow: pass if the code looks valid under any of *purposes*.

        Returns ``None`` on the first purpose that passes, otherwise the
        message produced for the last purpose tried.
        """
        last_message = _SMS_CONSUME_FALLBACK
        for purpose in purposes:
            record, message = await self._check_sms_code(phone, code, purpose)
            if record is not None:
                return None
            last_message = message
        return last_message

    async def _sms_invalid_message_after_consume_failure(
        self,
        phone: str,
        code: str,
        purposes: tuple[str, ...],
        *,
        fallback: str = _SMS_CONSUME_FALLBACK,
    ) -> str:
        """Re-read code state after an atomic consume failed (race/expiry).

        Returns the most specific failure reason across *purposes*. A concrete
        reason such as "验证码已过期" wins over the generic "not found / already
        used" message, because with several purposes (SMS login accepts both
        "login" and "register" codes) the first purpose often simply has no
        code at all. When every purpose still looks valid (a lost race),
        *fallback* is returned.
        """
        generic_message: str | None = None
        for purpose in purposes:
            record, message = await self._check_sms_code(phone, code, purpose)
            if record is None:
                if message != _SMS_CONSUME_FALLBACK:
                    return message
                if generic_message is None:
                    generic_message = message
        return generic_message if generic_message is not None else fallback

    async def _consume_sms_code_or_raise(
        self,
        phone: str,
        code: str,
        purpose: str,
        *,
        purposes: tuple[str, ...] | None = None,
    ) -> SmsVerificationCode:
        """Atomically consume an SMS code; call inside ``transactional()``.

        Tries each of *purposes* (or just *purpose* when *purposes* is not
        given) and returns the first record that could be consumed.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` when no purpose could be consumed;
                the message is re-derived from the current code state.
        """
        purposes_to_try = purposes or (purpose,)
        for p in purposes_to_try:
            consumed = await repo.try_consume_verification_code(
                phone, code, p, self._db
            )
            if consumed is not None:
                return consumed
        message = await self._sms_invalid_message_after_consume_failure(
            phone, code, purposes_to_try
        )
        raise AuthError(message, "INVALID_SMS_CODE")

    async def send_sms_code(
        self,
        phone: str,
        purpose: str,
        ip_address: str | None = None,
    ) -> tuple[bool, str, int]:
        """Generate, store and send an SMS verification code.

        Returns:
            ``(True, "验证码已发送", expires_in_seconds)`` on success.

        Raises:
            ServiceUnavailableError: SMS settings are incomplete.
            AuthError: ``PHONE_EXISTS`` for "register" on a registered phone,
                ``USER_NOT_FOUND`` for "reset_password" on an unknown phone.
            RateLimitedError: a code was sent less than
                ``RATE_LIMIT_SECONDS`` ago.
            ProviderError: the provider failed or raised; the stored code is
                marked expired first so it can never be used.
        """
        import math

        if not _sms_is_configured():
            raise ServiceUnavailableError("短信服务未配置,请稍后再试")
        if purpose == "register":
            if await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if purpose == "reset_password":
            if not await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号未注册", "USER_NOT_FOUND")

        recent = await repo.get_recent_code_for_rate_limit(phone, self._db)
        if recent:
            now = utc_now()
            # _as_utc: SQLite may return a naive created_at (aware - naive raises).
            elapsed = (now - _as_utc(recent.created_at)).total_seconds()
            if elapsed < RATE_LIMIT_SECONDS:
                # Round up so the message never says "0秒" while still limited.
                remaining = math.ceil(RATE_LIMIT_SECONDS - elapsed)
                raise RateLimitedError(f"发送过于频繁,请{remaining}秒后再试")

        code = self._generate_code()
        expires_at = utc_now() + timedelta(minutes=CODE_EXPIRE_MINUTES)
        record = SmsVerificationCode(
            id=str(uuid.uuid4()),
            phone=phone,
            code=code,
            purpose=purpose,
            expires_at=expires_at,
            ip_address=ip_address,
        )
        async with transactional(self._db):
            await repo.create_verification_code(record, self._db)

        # The provider is called synchronously (its bool result is used
        # directly), so run it in a worker thread instead of blocking the event
        # loop; upload_avatar does the same for object storage.
        try:
            sent = await asyncio.to_thread(
                self._sms.send_verification_code, phone, code
            )
        except Exception:
            logger.exception("SMS provider raised while sending verification code")
            sent = False
        if not sent:
            # Never leave an undelivered code usable.
            async with transactional(self._db):
                await repo.mark_verification_code_expired(record.id, self._db)
            raise ProviderError("短信发送失败,请稍后重试")
        return True, "验证码已发送", CODE_EXPIRE_MINUTES * 60

    async def _issue_tokens(self, user_id: str, device_info: str = "") -> dict:
        """Create an access + refresh token pair; add the refresh row to the session.

        Does not commit; callers run this inside ``transactional()``.

        Returns:
            ``{"access_token", "refresh_token", "refresh_token_id"}``. The id is
            internal and is stripped by ``_public_tokens`` before returning to
            API callers.
        """
        refresh_str = generate_refresh_token_str()
        token = RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user_id,
            token=refresh_str,
            expires_at=get_token_expires_at(),
            created_at=datetime.now(timezone.utc),
            device_info=device_info or None,
        )
        await repo.create_refresh_token(token, self._db)
        access = create_access_token(data={"sub": user_id})
        return {
            "access_token": access,
            "refresh_token": refresh_str,
            "refresh_token_id": token.id,
        }

    async def _try_idempotent_refresh_within_grace(
        self,
        token_record: RefreshToken,
    ) -> dict | None:
        """Grace-window retry: new access token + the existing replacement refresh.

        Applies only when *token_record* was rotated less than
        ``settings.refresh_token_reuse_grace_seconds`` ago and its replacement
        is still unrevoked, unexpired and owned by an existing user
        (presumably to absorb client retries of an already-applied rotation).
        Returns ``None`` in every other case.
        """
        grace = settings.refresh_token_reuse_grace_seconds
        if grace <= 0:
            return None
        if not token_record.replaced_by_token_id or token_record.rotated_at is None:
            return None
        now = utc_now()
        rotated_at = _as_utc(token_record.rotated_at)
        if rotated_at + timedelta(seconds=grace) <= now:
            return None
        replacement = await repo.get_refresh_token_by_id(
            token_record.replaced_by_token_id, self._db
        )
        if replacement is None or replacement.is_revoked:
            return None
        if _as_utc(replacement.expires_at) < now:
            return None
        user = await repo.get_user_by_id(replacement.user_id, self._db)
        if user is None:
            return None
        access = create_access_token(data={"sub": user.id})
        return {
            "access_token": access,
            "refresh_token": replacement.token,
        }

    # ── public API ───────────────────────────────────────────

    async def register(
        self,
        phone: str,
        password: str,
        nickname: str,
        email: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Register a new user with a password.

        Returns:
            ``{"user", "access_token", "refresh_token"}``.

        Raises:
            AuthError: ``PHONE_EXISTS`` / ``EMAIL_EXISTS``, from either the
                pre-check or a unique-constraint race at flush time.
        """
        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")

        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
            language_preference=_normalize_language(language),
        )
        async with transactional(self._db):
            await _create_user_with_integrity_check(
                self._db, user, phone_conflict="PHONE_EXISTS"
            )
            tokens = await self._issue_tokens(user_id)
        await self._db.refresh(user)
        return {"user": user, **_public_tokens(tokens)}

    async def login(
        self,
        phone: str,
        password: str,
        device_info: str = "",
    ) -> dict:
        """Log in with phone + password.

        Returns:
            ``{"user", "access_token", "refresh_token"}``.

        Raises:
            AuthError: ``INVALID_CREDENTIALS`` for an unknown phone or a wrong
                password (same message for both).
        """
        user = await repo.get_user_by_phone(phone, self._db)
        if not user or not verify_password(password, user.password_hash):
            raise AuthError("手机号或密码错误", "INVALID_CREDENTIALS")
        async with transactional(self._db):
            tokens = await self._issue_tokens(user.id, device_info)
        return {"user": user, **_public_tokens(tokens)}

    async def _revoke_all_active_tokens_in_session(self, user_id: str) -> int:
        """Flag all of the user's active refresh tokens revoked (no commit).

        Returns the number of tokens flagged.
        """
        tokens = await repo.get_active_tokens_for_user(user_id, self._db)
        for token in tokens:
            token.is_revoked = True
        return len(tokens)

    async def refresh_tokens(
        self,
        refresh_token: str,
        device_info: str = "",
    ) -> dict:
        """Rotate *refresh_token* and issue a new access/refresh pair.

        Flow:
        1. Atomically consume the token. On success, issue a new pair and link
           the old token to its replacement (rotation lineage).
        2. If the token is revoked, first try the grace window, which hands
           back the already-issued replacement.
        3. A revoked token outside the grace window that is not a possible
           in-flight rotation is treated as reuse: all of the user's active
           refresh tokens are revoked, and REFRESH_TOKEN_REUSE is raised after
           the transaction block exits.

        Returns:
            ``{"access_token", "refresh_token"}``.

        Raises:
            AuthError: ``USER_NOT_FOUND``, ``INVALID_TOKEN``, ``TOKEN_EXPIRED``
                or ``REFRESH_TOKEN_REUSE``.
        """
        reuse_detected = False
        async with transactional(self._db):
            consumed = await repo.try_consume_refresh_token(refresh_token, self._db)
            if consumed is not None:
                user = await repo.get_user_by_id(consumed.user_id, self._db)
                if not user:
                    raise AuthError("用户不存在", "USER_NOT_FOUND")
                issued = await self._issue_tokens(
                    consumed.user_id,
                    device_info or (consumed.device_info or ""),
                )
                await repo.link_refresh_rotation(
                    consumed.id,
                    issued["refresh_token_id"],
                    utc_now(),
                    self._db,
                )
                return _public_tokens(issued)

            token_record = await repo.get_refresh_token_by_token(
                refresh_token, self._db
            )
            if not token_record:
                raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
            if token_record.is_revoked:
                idempotent = await self._try_idempotent_refresh_within_grace(
                    token_record
                )
                if idempotent is None and not token_record.replaced_by_token_id:
                    # Concurrent refresh may observe revoke before lineage commits.
                    # NOTE(review): if repo.get_refresh_token_by_token goes through
                    # the session identity map without populate_existing, this may
                    # return the same cached object; confirm it re-reads the row.
                    token_record = (
                        await repo.get_refresh_token_by_token(
                            refresh_token, self._db
                        )
                        or token_record
                    )
                    idempotent = await self._try_idempotent_refresh_within_grace(
                        token_record
                    )
                if idempotent is not None:
                    return idempotent

                grace = settings.refresh_token_reuse_grace_seconds
                rotated_at = token_record.rotated_at
                still_in_grace = (
                    grace > 0
                    and rotated_at is not None
                    and _as_utc(rotated_at) + timedelta(seconds=grace) > utc_now()
                )
                if still_in_grace:
                    # Rotated recently, but the replacement could not be handed
                    # back (revoked/expired/user gone): reject without treating
                    # it as reuse.
                    raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
                if (
                    grace > 0
                    and rotated_at is None
                    and not token_record.replaced_by_token_id
                ):
                    # Revoke visible but lineage not committed yet (concurrent rotation).
                    raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
                logger.bind(user_id=token_record.user_id).warning(
                    "Refresh token reuse detected (grace expired or no lineage)"
                )
                await self._revoke_all_active_tokens_in_session(token_record.user_id)
                # Raise only after leaving transactional(); raising inside it
                # would presumably roll the revocation back.
                reuse_detected = True
            elif _as_utc(token_record.expires_at) < utc_now():
                raise AuthError("刷新令牌已过期", "TOKEN_EXPIRED")
            else:
                raise AuthError("无效的刷新令牌", "INVALID_TOKEN")

        if reuse_detected:
            raise AuthError("刷新令牌已失效,请重新登录", "REFRESH_TOKEN_REUSE")

    async def logout(self, refresh_token: str, user_id: str) -> None:
        """Revoke *refresh_token* if it belongs to *user_id*; otherwise do nothing."""
        token_record = await repo.get_refresh_token_by_token(refresh_token, self._db)
        if token_record and token_record.user_id == user_id:
            async with transactional(self._db):
                token_record.is_revoked = True

    async def logout_all(self, user_id: str) -> int:
        """Revoke every active refresh token of the user; return how many."""
        async with transactional(self._db):
            return await self._revoke_all_active_tokens_in_session(user_id)

    async def login_with_sms(
        self,
        phone: str,
        code: str,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """SMS login, auto-registering unknown phones.

        Codes issued for either the "login" or "register" purpose are accepted.

        Returns:
            ``{"user", "access_token", "refresh_token", "is_new_user"}``.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` when the code is missing, wrong,
                expired or already consumed.
        """
        precheck_error = await self._precheck_sms_code_for_purposes(
            phone, code, ("login", "register")
        )
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        return await self._sms_login_after_code_verified(
            phone,
            code=code,
            device_info=device_info,
            nickname=nickname,
            language=language,
        )

    async def _sms_login_after_code_verified(
        self,
        phone: str,
        *,
        code: str | None = None,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """After the SMS pre-check, do the rest in one transaction.

        Inside that transaction: atomically consume the code, create the user
        if needed, and issue tokens.

        ``language`` is only written on the new-user branch; an existing user's
        preference is not overwritten. The mock route passes ``code=None`` to
        skip code consumption.

        If a concurrent request created the same phone first (unique violation
        on phone), the savepoint is rolled back and the existing user is logged
        in instead.
        """
        user = await repo.get_user_by_phone(phone, self._db)
        is_new_user = user is None

        if is_new_user:
            user_id = str(uuid.uuid4())
            user = User(
                id=user_id,
                phone=phone,
                # Random, never-disclosed password: SMS-created accounts have no
                # usable password until one is set/reset.
                password_hash=hash_password(secrets.token_urlsafe(32)),
                nickname=(nickname or "").strip(),
                subscription_type="free",
                created_at=datetime.now(timezone.utc),
                language_preference=_normalize_language(language),
            )

        async with transactional(self._db):
            if code is not None:
                await self._consume_sms_code_or_raise(
                    phone, code, "login", purposes=("login", "register")
                )
            if is_new_user:
                try:
                    async with transactional_nested(self._db):
                        await repo.create_user(user, self._db)
                        await self._db.flush()
                except IntegrityError as exc:
                    if is_user_phone_unique_violation(exc):
                        existing = await repo.get_user_by_phone(phone, self._db)
                        if existing is None:
                            raise
                        user = existing
                        is_new_user = False
                    else:
                        _raise_auth_error_from_user_integrity(
                            exc, phone_conflict="PHONE_EXISTS"
                        )
            tokens = await self._issue_tokens(user.id, device_info)
        if is_new_user:
            await self._db.refresh(user)

        return {"user": user, "is_new_user": is_new_user, **_public_tokens(tokens)}

    async def mock_sms_login(
        self,
        phone: str,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Login/auto-register that skips SMS verification.

        Only meant to be called by the mock route when configuration allows it.
        """
        return await self._sms_login_after_code_verified(
            phone,
            device_info=device_info,
            nickname=nickname,
            language=language,
        )

    async def register_with_sms(
        self,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None = None,
        device_info: str = "",
        language: str | None = None,
    ) -> dict:
        """Register with an SMS code ("register" purpose) and a password.

        Returns:
            ``{"user", "access_token", "refresh_token"}``.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_EXISTS`` or ``EMAIL_EXISTS``.
        """
        precheck_error = await self._precheck_sms_code(phone, code, "register")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")

        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")

        return await self._register_after_sms_verified(
            phone=phone,
            code=code,
            password=password,
            nickname=nickname,
            email=email,
            device_info=device_info,
            language=language,
        )

    async def _register_after_sms_verified(
        self,
        *,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None,
        device_info: str,
        language: str | None,
    ) -> dict:
        """Consume the code, create the user and issue tokens in one transaction."""
        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
            language_preference=_normalize_language(language),
        )
        async with transactional(self._db):
            await self._consume_sms_code_or_raise(phone, code, "register")
            await _create_user_with_integrity_check(
                self._db, user, phone_conflict="PHONE_EXISTS"
            )
            tokens = await self._issue_tokens(user_id, device_info)
        await self._db.refresh(user)
        return {"user": user, **_public_tokens(tokens)}

    async def reset_password(
        self,
        phone: str,
        code: str,
        new_password: str,
    ) -> None:
        """Reset the password via an SMS code ("reset_password" purpose).

        All active refresh tokens of the user are revoked in the same
        transaction. A reset may be recovering a compromised account, so
        sessions issued under the old password must not keep refreshing.
        Access tokens that were already issued are not tracked here.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` or ``USER_NOT_FOUND``.
        """
        precheck_error = await self._precheck_sms_code(phone, code, "reset_password")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        user = await repo.get_user_by_phone(phone, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        # Hash before opening the transaction: hashing is deliberately slow and
        # need not hold the transaction open.
        new_hash = hash_password(new_password)
        async with transactional(self._db):
            await self._consume_sms_code_or_raise(phone, code, "reset_password")
            user.password_hash = new_hash
            await self._revoke_all_active_tokens_in_session(user.id)

    async def change_password(
        self,
        user_id: str,
        old_password: str,
        new_password: str,
    ) -> None:
        """Change the password after verifying the old one.

        Raises:
            AuthError: ``USER_NOT_FOUND`` or ``WRONG_PASSWORD``.
        """
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        if not verify_password(old_password, user.password_hash):
            raise AuthError("旧密码错误", "WRONG_PASSWORD")

        # Hash outside the transaction to keep it short.
        new_hash = hash_password(new_password)
        async with transactional(self._db):
            user.password_hash = new_hash

    async def change_phone(
        self,
        user_id: str,
        new_phone: str,
        code: str,
    ) -> User:
        """Change the phone number using a code sent to *new_phone*.

        Returns the refreshed user.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_TAKEN`` (pre-check or
                unique-constraint race) or ``USER_NOT_FOUND``.
        """
        precheck_error = await self._precheck_sms_code(new_phone, code, "change_phone")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        existing = await repo.get_user_by_phone(new_phone, self._db)
        if existing and existing.id != user_id:
            raise AuthError("该手机号已被其他用户使用", "PHONE_TAKEN")

        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        async with transactional(self._db):
            await self._consume_sms_code_or_raise(new_phone, code, "change_phone")
            user.phone = new_phone
            try:
                await self._db.flush()
            except IntegrityError as exc:
                _raise_auth_error_from_user_integrity(exc, phone_conflict="PHONE_TAKEN")
        await self._db.refresh(user)
        return user

    async def update_nickname(self, user_id: str, nickname: str) -> User:
        """Set the whitespace-stripped nickname; return the refreshed user."""
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        async with transactional(self._db):
            user.nickname = nickname.strip()
        await self._db.refresh(user)
        return user

    async def update_avatar_url(self, user_id: str, avatar_url: str) -> User:
        """Persist *avatar_url* on the user; return the refreshed user."""
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        async with transactional(self._db):
            user.avatar_url = avatar_url
        await self._db.refresh(user)
        return user

    async def upload_avatar(
        self,
        user_id: str,
        file_content: bytes,
        content_type: str,
        *,
        old_avatar_url: str | None,
    ) -> User:
        """Validate and process an avatar, upload it to COS, and persist its URL.

        The avatar is always stored at ``avatars/{user_id}.jpg``. After the DB
        update succeeds, a previous avatar owned by us at a different key is
        deleted on a best-effort basis.

        Raises:
            BadRequestError: unsupported type, empty, over 5MB, or undecodable.
            ServiceUnavailableError: storage not configured or upload failed.
        """
        allowed_types = ("image/jpeg", "image/png", "image/webp")
        if content_type not in allowed_types:
            raise BadRequestError(
                f"不支持的文件类型。仅支持: {', '.join(allowed_types)}"
            )
        if not file_content:
            raise BadRequestError("文件内容为空")
        if len(file_content) > 5 * 1024 * 1024:
            raise BadRequestError("文件大小超过5MB限制")
        storage_configured = bool(
            (settings.tencent_secret_id or "").strip()
            and (settings.tencent_secret_key or "").strip()
            and (settings.tencent_cos_bucket or "").strip()
        )
        # Check the storage port up front, before spending CPU on image decoding.
        if not storage_configured or not self._object_storage:
            raise ServiceUnavailableError("头像存储服务未配置,请稍后再试")

        jpeg_bytes = await _process_avatar_jpeg_async(file_content)
        cos_key = f"avatars/{user_id}.jpg"
        old_key = extract_cos_object_key_if_owned(old_avatar_url) if old_avatar_url else None

        try:
            avatar_url = await asyncio.to_thread(
                self._object_storage.upload, cos_key, jpeg_bytes, "image/jpeg"
            )
        except Exception as exc:
            logger.exception("COS 头像上传失败: {}", exc)
            raise ServiceUnavailableError("头像存储暂时不可用,请稍后再试") from exc

        try:
            user = await self.update_avatar_url(user_id, avatar_url)
        except Exception:
            # Undo the upload only if it created a new object. When the old
            # avatar already lived at cos_key, the upload overwrote it in place
            # and the unchanged DB row still points there; deleting it would
            # leave the user with a dangling avatar URL.
            if old_key != cos_key:
                try:
                    await asyncio.to_thread(self._object_storage.delete, cos_key)
                except Exception as cleanup_exc:
                    logger.warning(
                        "头像 DB 写入失败后清理 COS 对象失败: key={} err={}",
                        cos_key,
                        cleanup_exc,
                    )
            raise

        if old_key and old_key != cos_key:
            best_effort_delete_cos_object_for_url(old_avatar_url)
        return user
|
|
|
|
|
|
def _is_valid_image_header(header: bytes) -> bool:
|
|
if header.startswith(b"\xff\xd8\xff"):
|
|
return True
|
|
if header.startswith(b"\x89PNG\r\n\x1a\n"):
|
|
return True
|
|
if header.startswith(b"RIFF") and b"WEBP" in header[:12]:
|
|
return True
|
|
return False
|
|
|
|
|
|
async def _process_avatar_jpeg_async(file_content: bytes) -> bytes:
    """Run the CPU-bound avatar pipeline in a worker thread.

    ``BadRequestError`` from the pipeline propagates unchanged; any other
    failure (e.g. Pillow cannot decode the data) becomes
    ``BadRequestError("无效的图片文件")`` chained to the original exception.
    """
    try:
        return await asyncio.to_thread(_process_avatar_jpeg, file_content)
    except Exception as failure:
        if isinstance(failure, BadRequestError):
            raise
        raise BadRequestError("无效的图片文件") from failure
|
|
|
|
|
|
def _process_avatar_jpeg(file_content: bytes) -> bytes:
    """Decode an uploaded avatar and re-encode it as a square RGB JPEG.

    Steps:
    1. Sniff the magic bytes.
    2. Apply the EXIF orientation.
    3. Convert to RGB.
    4. Center-crop to a square.
    5. Downscale to 512x512 when larger.
    6. Save as JPEG (quality 85, optimized).

    Raises:
        BadRequestError: the header is not a JPEG/PNG/WebP signature.
        Pillow errors for undecodable data propagate unchanged; the async
        wrapper in this module maps them to BadRequestError.
    """
    # Function-scope import: the module header only imports ``Image``.
    from PIL import ImageOps

    image_bytes = io.BytesIO(file_content)
    header = image_bytes.read(16)
    image_bytes.seek(0)
    if not _is_valid_image_header(header):
        raise BadRequestError(f"无效的图片文件格式。文件头: {header[:12].hex()}")

    with Image.open(image_bytes) as opened:
        try:
            # Cameras commonly record rotation in the EXIF Orientation tag
            # instead of rotating pixels; without this, such photos would be
            # cropped and saved sideways/upside down.
            image = ImageOps.exif_transpose(opened)
        except Exception:
            # Best effort: malformed EXIF must not reject an otherwise
            # decodable image, so fall back to the pixels as stored.
            image = opened
        if image.mode != "RGB":
            image = image.convert("RGB")

        width, height = image.size
        size = min(width, height)
        left = (width - size) // 2
        top = (height - size) // 2
        image = image.crop((left, top, left + size, top + size))
        if size > 512:
            image = image.resize((512, 512), Image.Resampling.LANCZOS)

        jpeg_buffer = io.BytesIO()
        image.save(jpeg_buffer, format="JPEG", quality=85, optimize=True)
    return jpeg_buffer.getvalue()
|