Files
life-echo/api/app/features/auth/service.py
Sully 105b50a277 merge dark mode and google OAuth (#35)
* feat(api): implement Google OAuth login and user management

- Added Google OpenID Connect login functionality, allowing users to authenticate using their Google accounts.
- Created new endpoints for Google login, including user registration and linking existing accounts.
- Introduced Google token verification logic and error handling for authentication failures.
- Updated environment configuration to include Google OAuth client IDs and verification settings.
- Enhanced user model to support OpenID and linked Google accounts.

This feature improves user experience by enabling seamless sign-in with Google, while maintaining security and integrity of user data.

* fix(auth): wire staging Google token verifier

* chore(deps): update expo to version 55.0.6 and adjust @expo/env dependency in pnpm-lock.yaml

* chore(deps): update Babel dependencies to version 7.29.7 in package-lock.json

* feat(auth): enhance phone login for China users

- Updated phone login functionality to support only mainland China (+86) mobile numbers.
- Added user prompts and descriptions for phone login, including confirmation and cancellation options.
- Adjusted translations for both English and Chinese to reflect the new phone login requirements.
- Updated Google OAuth client IDs in configuration files for production and staging environments.

* chore(deps): add peer flag to use-sync-external-store in package-lock.json

* chore(deps): add @emnapi/core and @emnapi/runtime to package-lock.json

* fix(app-expo): align Android native dependencies

* fix(app-expo): normalize lockfile for npm 10

* fix(config): update environment variable handling to use static access

- Introduced a static mapping for public environment variables so they remain accessible in the release bundle.
- Updated the `requirePublicEnv` and `optionalPublicEnv` functions to reference the new `PUBLIC_ENV` object instead of directly accessing `process.env`.
- Added comments to clarify the necessity of static access for certain environment variables.

* feat(app-expo): dark mode, FAQ i18n, eval ASR, and theme cleanup (#34)

* feat(app-expo): dark mode, FAQ i18n, version CI, and theme cleanup

Implement light/dark scene colors across chat, reading, and headers; remove
default/brand theme picker and ThemeVariablesProvider. Localize FAQ in-app,
fix dark-mode text visibility, and remove the unused /api/faqs endpoint.
Align About/version with Expo config and inject APP_VERSION in CI builds.

Also includes phone E164 auth/SMS updates, eval ASR page, and related API work.

* revert: remove phone E.164 changes from dark-mode branch

These auth/SMS internationalization updates were accidentally bundled into
the dark-mode commit; restore 11-digit CN phone flow and drop related API,
migration, and Expo UI work from this branch.

* fix: address PR review issues for dark mode and eval ASR

Use light foreground colors for sepia reading in dark mode, fix chat send
button contrast, stream-limit eval ASR uploads, restore LiveTester phone
validation, and remove unused AudioSegmenter code.

* fix(app-expo): improve chat send button contrast in light and dark mode

Add dedicated send button colors (accent fill in dark, primary fill in
light), use RNText to avoid NativeWind overrides, and restore dark labels
in light mode for readable composer actions.

---------

Co-authored-by: Kevin <kevin@brighteng.org>

---------

Co-authored-by: penghanyuan <penghanyuan@gmail.com>
Co-authored-by: Kevin <kevin@brighteng.org>
2026-06-09 11:14:36 +08:00

857 lines
31 KiB
Python

import asyncio
import io
import random
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from PIL import Image
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.cos_url_keys import (
best_effort_delete_cos_object_for_url,
extract_cos_object_key_if_owned,
)
from app.core.db import transactional, transactional_nested, utc_now
from app.core.errors import (
BadRequestError,
ProviderError,
RateLimitedError,
ServiceUnavailableError,
)
from app.core.logging import get_logger
from app.core.security import (
create_access_token,
get_token_expires_at,
hash_password,
verify_password,
)
from app.core.security import (
create_refresh_token as generate_refresh_token_str,
)
from app.features.auth import repo
from app.features.auth.google import GoogleIdentity, verify_google_id_token
from app.features.auth.integrity import (
is_user_phone_unique_violation,
user_integrity_auth_code,
)
from app.features.auth.models import RefreshToken, SmsVerificationCode
from app.features.auth.service_errors import AuthError
from app.features.user.models import User
from app.ports.sms import SmsSender
from app.ports.storage import ObjectStorage
logger = get_logger(__name__)
# SMS verification-code policy.
CODE_LENGTH: int = 6  # digits per generated code
CODE_EXPIRE_MINUTES: int = 5  # lifetime of a code once it has been stored
RATE_LIMIT_SECONDS: int = 60  # minimum gap between sends to the same phone
# Generic user-facing message when no usable code exists (missing or already used).
_SMS_CONSUME_FALLBACK: str = "验证码不存在或已使用"
def _sms_is_configured() -> bool:
    """Return True only when every Tencent SMS setting is a non-blank string.

    ``None`` and whitespace-only values both count as missing.
    """
    required_settings = (
        settings.tencent_secret_id,
        settings.tencent_secret_key,
        settings.tencent_sms_sdk_app_id,
        settings.tencent_sms_sign_name,
        settings.tencent_sms_template_id,
    )
    return all((value or "").strip() for value in required_settings)
_VALID_LANGUAGES = {"zh", "en"}
def _normalize_language(lang: str | None) -> str:
"""Normalize device language token; default to zh on missing/unknown."""
if not lang:
return "zh"
s = str(lang).strip().lower()
return s if s in _VALID_LANGUAGES else "zh"
def _as_utc(dt: datetime) -> datetime:
"""Normalize DB datetimes for safe comparison (sqlite may return naive)."""
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def _raise_auth_error_from_user_integrity(
    exc: IntegrityError,
    *,
    phone_conflict: str,
) -> None:
    """Translate a user-table IntegrityError into a domain AuthError.

    ``phone_conflict`` is forwarded to ``user_integrity_auth_code`` to choose
    the code reported for a phone clash. Recognised codes become an
    ``AuthError`` chained to *exc*; anything else re-raises *exc* unchanged.
    This function never returns normally.
    """
    messages_by_code = {
        "PHONE_EXISTS": "该手机号已被注册",
        "EMAIL_EXISTS": "该邮箱已被注册",
        "PHONE_TAKEN": "该手机号已被其他用户使用",
    }
    auth_code = user_integrity_auth_code(exc, phone_conflict=phone_conflict)
    message = messages_by_code.get(auth_code)
    if message is None:
        raise exc
    raise AuthError(message, auth_code) from exc
async def _create_user_with_integrity_check(
    db: AsyncSession,
    user: User,
    *,
    phone_conflict: str,
) -> None:
    """Stage *user* and flush at once so uniqueness violations surface here.

    Callers in this module run this inside ``transactional()``.
    ``phone_conflict`` selects the error code reported for a duplicate phone.

    Raises:
        AuthError: for recognised phone/email uniqueness violations.
        IntegrityError: re-raised when the violation is not recognised.
    """
    await repo.create_user(user, db)
    try:
        await db.flush()
    except IntegrityError as integrity_err:
        _raise_auth_error_from_user_integrity(
            integrity_err, phone_conflict=phone_conflict
        )
def _public_tokens(issued: dict) -> dict:
"""Strip internal fields before returning tokens to callers."""
return {
"access_token": issued["access_token"],
"refresh_token": issued["refresh_token"],
}
def _google_openid(subject: str) -> str:
return f"google:{subject}"
def _google_internal_phone(subject: str) -> str:
return f"google:{subject}"
def _nickname_from_google_identity(identity: GoogleIdentity) -> str:
    """Choose a display nickname (at most 50 characters) for a Google sign-in.

    Preference order: the stripped Google profile name, then the stripped
    local part of the email address, then the literal ``"Google User"``.
    """
    chosen = (
        identity.name.strip()
        or identity.email.split("@", 1)[0].strip()
        or "Google User"
    )
    return chosen[:50]
class AuthService:
    """Authentication and account use cases backed by an async DB session.

    Covers SMS verification codes, password / SMS / Google sign-in,
    refresh-token rotation with reuse detection, logout, password and phone
    changes, and profile (nickname / avatar) updates. Writes are wrapped in
    ``transactional(self._db)``. SMS delivery and avatar storage go through
    the injected ``SmsSender`` and optional ``ObjectStorage`` ports.
    """

    def __init__(
        self,
        db: AsyncSession,
        sms: SmsSender,
        *,
        object_storage: ObjectStorage | None = None,
    ):
        self._db = db
        self._sms = sms
        # Optional: only upload_avatar uses it, and it raises
        # ServiceUnavailableError when this is missing.
        self._object_storage = object_storage

    # ── private helpers ──────────────────────────────────────

    def _generate_code(self) -> str:
        """Return a CODE_LENGTH-digit numeric verification code.

        Uses ``secrets`` (a CSPRNG): the code authenticates the phone owner,
        and the ``random`` module's generator is predictable.
        """
        return "".join(secrets.choice("0123456789") for _ in range(CODE_LENGTH))

    async def _check_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> tuple[SmsVerificationCode | None, str]:
        """Validate an SMS code without consuming it.

        UX pre-check only. The authoritative validation is
        ``repo.try_consume_verification_code`` inside a transaction.

        Returns:
            ``(record, "验证成功")`` when the latest unused code for
            ``(phone, purpose)`` is unexpired and matches ``code``. Otherwise
            ``(None, <user-facing reason>)``. As a side effect, an expired
            record gets ``is_expired = True``.
        """
        record = await repo.get_latest_unused_code(phone, purpose, self._db)
        if not record:
            return None, _SMS_CONSUME_FALLBACK
        # utc_now() is compared against _as_utc(...) values elsewhere in this
        # class (i.e. it is tz-aware); sqlite may return naive expires_at, and
        # comparing naive with aware raises TypeError.
        if utc_now() > _as_utc(record.expires_at):
            # NOTE(review): this method is also reached from
            # _consume_sms_code_or_raise, which runs inside an outer
            # transactional() block — confirm transactional() tolerates nesting.
            async with transactional(self._db):
                record.is_expired = True
            return None, "验证码已过期"
        if record.code != code:
            return None, "验证码错误"
        return record, "验证成功"

    async def _precheck_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> str | None:
        """UX pre-check: fail fast without consuming the code.

        Returns None when the code currently looks valid, otherwise the
        user-facing failure message.
        """
        record, message = await self._check_sms_code(phone, code, purpose)
        if record is None:
            return message
        return None

    async def _precheck_sms_code_for_purposes(
        self, phone: str, code: str, purposes: tuple[str, ...]
    ) -> str | None:
        """Login flow: try each purpose until one pre-check passes.

        Returns None on the first passing purpose. Otherwise returns the
        message from the last purpose tried, or the generic fallback when
        ``purposes`` is empty.
        """
        last_message = _SMS_CONSUME_FALLBACK
        for purpose in purposes:
            record, message = await self._check_sms_code(phone, code, purpose)
            if record is not None:
                return None
            last_message = message
        return last_message

    async def _sms_invalid_message_after_consume_failure(
        self,
        phone: str,
        code: str,
        purposes: tuple[str, ...],
        *,
        fallback: str = _SMS_CONSUME_FALLBACK,
    ) -> str:
        """Explain why an atomic consume failed (race/expiry/concurrency).

        Re-reads code state and returns the first failing purpose's message,
        or ``fallback`` when every purpose still pre-checks as valid.
        """
        for purpose in purposes:
            record, message = await self._check_sms_code(phone, code, purpose)
            if record is None:
                return message
        return fallback

    async def _consume_sms_code_or_raise(
        self,
        phone: str,
        code: str,
        purpose: str,
        *,
        purposes: tuple[str, ...] | None = None,
    ) -> SmsVerificationCode:
        """Atomically consume an SMS code; call inside ``transactional()``.

        Tries each of ``purposes`` (default: just ``purpose``) in order and
        returns the first record successfully consumed.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` when no purpose could be consumed.
        """
        purposes_to_try = purposes or (purpose,)
        for candidate in purposes_to_try:
            consumed = await repo.try_consume_verification_code(
                phone, code, candidate, self._db
            )
            if consumed is not None:
                return consumed
        message = await self._sms_invalid_message_after_consume_failure(
            phone, code, purposes_to_try
        )
        raise AuthError(message, "INVALID_SMS_CODE")

    async def _expire_verification_code(self, record_id: str) -> None:
        """Mark one stored verification code expired in its own transaction."""
        async with transactional(self._db):
            await repo.mark_verification_code_expired(record_id, self._db)

    async def send_sms_code(
        self,
        phone: str,
        purpose: str,
        ip_address: str | None = None,
    ) -> tuple[bool, str, int]:
        """Send an SMS verification code.

        Returns:
            ``(True, message, expires_in_seconds)`` on success.

        Raises:
            ServiceUnavailableError: SMS settings are incomplete.
            AuthError: ``PHONE_EXISTS`` (register) / ``USER_NOT_FOUND``
                (reset_password) precondition failures.
            RateLimitedError: a code was sent to this phone too recently.
            ProviderError: the SMS provider reported a failed send.
        """
        if not _sms_is_configured():
            raise ServiceUnavailableError("短信服务未配置,请稍后再试")
        if purpose == "register":
            if await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if purpose == "reset_password":
            if not await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号未注册", "USER_NOT_FOUND")
        recent = await repo.get_recent_code_for_rate_limit(phone, self._db)
        if recent:
            # _as_utc: sqlite may return a naive created_at.
            elapsed = (utc_now() - _as_utc(recent.created_at)).total_seconds()
            if elapsed < RATE_LIMIT_SECONDS:
                # Never tell the user to wait "0 seconds" while still blocked.
                remaining = max(1, int(RATE_LIMIT_SECONDS - elapsed))
                raise RateLimitedError(f"发送过于频繁,请{remaining}秒后再试")
        code = self._generate_code()
        expires_at = utc_now() + timedelta(minutes=CODE_EXPIRE_MINUTES)
        record = SmsVerificationCode(
            id=str(uuid.uuid4()),
            phone=phone,
            code=code,
            purpose=purpose,
            expires_at=expires_at,
            ip_address=ip_address,
        )
        async with transactional(self._db):
            await repo.create_verification_code(record, self._db)
        try:
            # The provider call is synchronous; run it off the event loop like
            # the other blocking calls in this class.
            sent = await asyncio.to_thread(
                self._sms.send_verification_code, phone, code
            )
        except Exception:
            # The user never received this code: invalidate it, then propagate.
            await self._expire_verification_code(record.id)
            raise
        if not sent:
            await self._expire_verification_code(record.id)
            raise ProviderError("短信发送失败,请稍后重试")
        return True, "验证码已发送", CODE_EXPIRE_MINUTES * 60

    async def _issue_tokens(self, user_id: str, device_info: str = "") -> dict:
        """Create an access + refresh token pair; stage the refresh token.

        Does not commit; every caller runs it inside ``transactional()``.
        Returns ``access_token``, ``refresh_token`` and the internal
        ``refresh_token_id`` (strip with ``_public_tokens`` before returning).
        """
        refresh_str = generate_refresh_token_str()
        token = RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user_id,
            token=refresh_str,
            expires_at=get_token_expires_at(),
            created_at=datetime.now(timezone.utc),
            device_info=device_info or None,
        )
        await repo.create_refresh_token(token, self._db)
        access = create_access_token(data={"sub": user_id})
        return {
            "access_token": access,
            "refresh_token": refresh_str,
            "refresh_token_id": token.id,
        }

    async def _try_idempotent_refresh_within_grace(
        self,
        token_record: RefreshToken,
    ) -> dict | None:
        """Grace-window retry: new access token + the existing replacement refresh.

        Returns None unless all of the following hold: grace is enabled; the
        token has rotation lineage; rotation happened within
        ``refresh_token_reuse_grace_seconds``; the replacement token is
        unrevoked and unexpired; and its user still exists.
        """
        grace = settings.refresh_token_reuse_grace_seconds
        if grace <= 0:
            return None
        if not token_record.replaced_by_token_id or token_record.rotated_at is None:
            return None
        now = utc_now()
        rotated_at = _as_utc(token_record.rotated_at)
        if rotated_at + timedelta(seconds=grace) <= now:
            return None
        replacement = await repo.get_refresh_token_by_id(
            token_record.replaced_by_token_id, self._db
        )
        if replacement is None or replacement.is_revoked:
            return None
        if _as_utc(replacement.expires_at) < now:
            return None
        user = await repo.get_user_by_id(replacement.user_id, self._db)
        if user is None:
            return None
        access = create_access_token(data={"sub": user.id})
        return {
            "access_token": access,
            "refresh_token": replacement.token,
        }

    # ── public API ───────────────────────────────────────────

    async def register(
        self,
        phone: str,
        password: str,
        nickname: str,
        email: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Register a new password user. Returns {user, access_token, refresh_token}.

        Raises:
            AuthError: ``PHONE_EXISTS`` / ``EMAIL_EXISTS``, either from the
                pre-check or from the unique constraint at flush time.
        """
        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")
        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
            language_preference=_normalize_language(language),
        )
        async with transactional(self._db):
            await _create_user_with_integrity_check(
                self._db, user, phone_conflict="PHONE_EXISTS"
            )
            tokens = await self._issue_tokens(user_id)
        await self._db.refresh(user)
        return {"user": user, **_public_tokens(tokens)}

    async def login(
        self,
        phone: str,
        password: str,
        device_info: str = "",
    ) -> dict:
        """Log in with phone + password. Returns {user, access_token, refresh_token}.

        Raises:
            AuthError: ``INVALID_CREDENTIALS`` for an unknown phone or a wrong
                password (same message for both).
        """
        user = await repo.get_user_by_phone(phone, self._db)
        if not user or not verify_password(password, user.password_hash):
            raise AuthError("手机号或密码错误", "INVALID_CREDENTIALS")
        async with transactional(self._db):
            tokens = await self._issue_tokens(user.id, device_info)
        return {"user": user, **_public_tokens(tokens)}

    async def login_with_google(
        self,
        id_token: str,
        device_info: str = "",
        language: str | None = None,
    ) -> dict:
        """Google OpenID Connect login; auto-registers on first sign-in.

        Resolves the user by ``google:<sub>`` openid. If that fails, it links
        an existing account with the same email. If no account matches, it
        creates a new one. Returns {user, is_new_user, access_token,
        refresh_token}.

        Raises:
            AuthError: ``EMAIL_EXISTS`` when the email's account is already
                bound to a different third-party openid.
        """
        # Token verification is synchronous; keep it off the event loop.
        identity = await asyncio.to_thread(verify_google_id_token, id_token)
        openid = _google_openid(identity.subject)
        user = await repo.get_user_by_openid(openid, self._db)
        is_new_user = False
        async with transactional(self._db):
            if user is None:
                # NOTE(review): linking by email is only safe if
                # verify_google_id_token rejects unverified emails — confirm.
                user = await repo.get_user_by_email(identity.email, self._db)
                if user is not None:
                    if user.openid and user.openid != openid:
                        raise AuthError("该邮箱已绑定其他第三方账号", "EMAIL_EXISTS")
                    user.openid = openid
                    if not user.email:
                        user.email = identity.email
                    if identity.picture and not user.avatar_url:
                        user.avatar_url = identity.picture
                else:
                    is_new_user = True
                    user = User(
                        id=str(uuid.uuid4()),
                        phone=_google_internal_phone(identity.subject),
                        # Random, never-disclosed password: Google users sign
                        # in via Google, not with a password.
                        password_hash=hash_password(secrets.token_urlsafe(32)),
                        email=identity.email,
                        openid=openid,
                        nickname=_nickname_from_google_identity(identity),
                        avatar_url=identity.picture,
                        subscription_type="free",
                        created_at=datetime.now(timezone.utc),
                        language_preference=_normalize_language(language),
                    )
                    await _create_user_with_integrity_check(
                        self._db, user, phone_conflict="PHONE_EXISTS"
                    )
            tokens = await self._issue_tokens(user.id, device_info)
        if is_new_user:
            await self._db.refresh(user)
        return {"user": user, "is_new_user": is_new_user, **_public_tokens(tokens)}

    async def _revoke_all_active_tokens_in_session(self, user_id: str) -> int:
        """Revoke all of a user's active refresh tokens in the session (no commit).

        Returns the number of tokens revoked.
        """
        tokens = await repo.get_active_tokens_for_user(user_id, self._db)
        for token in tokens:
            token.is_revoked = True
        return len(tokens)

    async def refresh_tokens(
        self,
        refresh_token: str,
        device_info: str = "",
    ) -> dict:
        """Rotate a refresh token and issue a new access/refresh pair.

        Inside one transaction:
        1. Atomically consume the token. On success, issue a new pair and
           record the rotation lineage.
        2. If the token was already revoked, a retry within
           ``refresh_token_reuse_grace_seconds`` of rotation gets the
           existing replacement back (idempotent refresh).
        3. A revoked token that is neither an in-grace retry nor a possibly
           in-flight concurrent rotation is treated as reuse. All of the
           user's active tokens are revoked, the transaction commits, and
           then ``REFRESH_TOKEN_REUSE`` is raised.

        Raises:
            AuthError: ``INVALID_TOKEN``, ``TOKEN_EXPIRED``,
                ``USER_NOT_FOUND`` or ``REFRESH_TOKEN_REUSE``.
        """
        reuse_detected = False
        async with transactional(self._db):
            consumed = await repo.try_consume_refresh_token(refresh_token, self._db)
            if consumed is not None:
                user = await repo.get_user_by_id(consumed.user_id, self._db)
                if not user:
                    raise AuthError("用户不存在", "USER_NOT_FOUND")
                issued = await self._issue_tokens(
                    consumed.user_id,
                    device_info or (consumed.device_info or ""),
                )
                await repo.link_refresh_rotation(
                    consumed.id,
                    issued["refresh_token_id"],
                    utc_now(),
                    self._db,
                )
                return _public_tokens(issued)
            token_record = await repo.get_refresh_token_by_token(
                refresh_token, self._db
            )
            if not token_record:
                raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
            if token_record.is_revoked:
                idempotent = await self._try_idempotent_refresh_within_grace(
                    token_record
                )
                if idempotent is None and not token_record.replaced_by_token_id:
                    # Concurrent refresh may observe revoke before lineage commits.
                    token_record = (
                        await repo.get_refresh_token_by_token(refresh_token, self._db)
                        or token_record
                    )
                    idempotent = await self._try_idempotent_refresh_within_grace(
                        token_record
                    )
                if idempotent is not None:
                    return idempotent
                grace = settings.refresh_token_reuse_grace_seconds
                rotated_at = token_record.rotated_at
                still_in_grace = (
                    grace > 0
                    and rotated_at is not None
                    and _as_utc(rotated_at) + timedelta(seconds=grace) > utc_now()
                )
                if still_in_grace:
                    raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
                if (
                    grace > 0
                    and rotated_at is None
                    and not token_record.replaced_by_token_id
                ):
                    # Revoke visible but lineage not committed yet (concurrent rotation).
                    raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
                logger.bind(user_id=token_record.user_id).warning(
                    "Refresh token reuse detected (grace expired or no lineage)"
                )
                await self._revoke_all_active_tokens_in_session(token_record.user_id)
                # Raise only after the transaction exits so the revocations commit.
                reuse_detected = True
            elif _as_utc(token_record.expires_at) < utc_now():
                raise AuthError("刷新令牌已过期", "TOKEN_EXPIRED")
            else:
                raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
        if reuse_detected:
            raise AuthError("刷新令牌已失效,请重新登录", "REFRESH_TOKEN_REUSE")

    async def logout(self, refresh_token: str, user_id: str) -> None:
        """Revoke a refresh token, but only if it belongs to ``user_id``.

        Unknown tokens and tokens owned by other users are ignored silently.
        """
        token_record = await repo.get_refresh_token_by_token(refresh_token, self._db)
        if token_record and token_record.user_id == user_id:
            async with transactional(self._db):
                token_record.is_revoked = True

    async def logout_all(self, user_id: str) -> int:
        """Revoke all active refresh tokens for a user. Returns the count revoked."""
        async with transactional(self._db):
            return await self._revoke_all_active_tokens_in_session(user_id)

    async def login_with_sms(
        self,
        phone: str,
        code: str,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """SMS login, auto-registering if new.

        Accepts codes sent for either ``login`` or ``register``. Returns
        {user, access_token, refresh_token, is_new_user}.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` when the code is missing, wrong,
                or expired.
        """
        precheck_error = await self._precheck_sms_code_for_purposes(
            phone, code, ("login", "register")
        )
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")
        return await self._sms_login_after_code_verified(
            phone,
            code=code,
            device_info=device_info,
            nickname=nickname,
            language=language,
        )

    async def _sms_login_after_code_verified(
        self,
        phone: str,
        *,
        code: str | None = None,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Finish an SMS login after the pre-check passed.

        In one transaction it atomically consumes the code, creates the user
        if needed, and issues tokens. ``language`` is written only for a new
        user; an existing user's preference is not overwritten. The mock
        route passes ``code=None`` to skip consuming a code.
        """
        user = await repo.get_user_by_phone(phone, self._db)
        is_new_user = user is None
        if is_new_user:
            user_id = str(uuid.uuid4())
            user = User(
                id=user_id,
                phone=phone,
                password_hash=hash_password(secrets.token_urlsafe(32)),
                nickname=(nickname or "").strip(),
                subscription_type="free",
                created_at=datetime.now(timezone.utc),
                language_preference=_normalize_language(language),
            )
        async with transactional(self._db):
            if code is not None:
                await self._consume_sms_code_or_raise(
                    phone, code, "login", purposes=("login", "register")
                )
            if is_new_user:
                try:
                    # Nested transaction so a lost insert race can be recovered
                    # without aborting the outer transaction.
                    async with transactional_nested(self._db):
                        await repo.create_user(user, self._db)
                        await self._db.flush()
                except IntegrityError as exc:
                    if is_user_phone_unique_violation(exc):
                        # A concurrent request registered this phone first;
                        # log in as that user instead.
                        existing = await repo.get_user_by_phone(phone, self._db)
                        if existing is None:
                            raise
                        user = existing
                        is_new_user = False
                    else:
                        _raise_auth_error_from_user_integrity(
                            exc, phone_conflict="PHONE_EXISTS"
                        )
            tokens = await self._issue_tokens(user.id, device_info)
        if is_new_user:
            await self._db.refresh(user)
        return {"user": user, "is_new_user": is_new_user, **_public_tokens(tokens)}

    async def mock_sms_login(
        self,
        phone: str,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Log in / auto-register without SMS verification.

        Only the mock route calls this, and only when configuration allows it.
        """
        return await self._sms_login_after_code_verified(
            phone,
            device_info=device_info,
            nickname=nickname,
            language=language,
        )

    async def register_with_sms(
        self,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None = None,
        device_info: str = "",
        language: str | None = None,
    ) -> dict:
        """Register with a ``register`` SMS code. Returns {user, access_token, refresh_token}.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_EXISTS`` or ``EMAIL_EXISTS``.
        """
        precheck_error = await self._precheck_sms_code(phone, code, "register")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")
        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")
        return await self._register_after_sms_verified(
            phone=phone,
            code=code,
            password=password,
            nickname=nickname,
            email=email,
            device_info=device_info,
            language=language,
        )

    async def _register_after_sms_verified(
        self,
        *,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None,
        device_info: str,
        language: str | None,
    ) -> dict:
        """Consume the code, create the user and issue tokens in one transaction."""
        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
            language_preference=_normalize_language(language),
        )
        async with transactional(self._db):
            await self._consume_sms_code_or_raise(phone, code, "register")
            await _create_user_with_integrity_check(
                self._db, user, phone_conflict="PHONE_EXISTS"
            )
            tokens = await self._issue_tokens(user_id, device_info)
        await self._db.refresh(user)
        return {"user": user, **_public_tokens(tokens)}

    async def reset_password(
        self,
        phone: str,
        code: str,
        new_password: str,
    ) -> None:
        """Reset a password with a ``reset_password`` SMS code.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` or ``USER_NOT_FOUND``.
        """
        precheck_error = await self._precheck_sms_code(phone, code, "reset_password")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")
        user = await repo.get_user_by_phone(phone, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        async with transactional(self._db):
            await self._consume_sms_code_or_raise(phone, code, "reset_password")
            user.password_hash = hash_password(new_password)

    async def change_password(
        self,
        user_id: str,
        old_password: str,
        new_password: str,
    ) -> None:
        """Change a password after verifying the old one.

        Raises:
            AuthError: ``USER_NOT_FOUND`` or ``WRONG_PASSWORD``.
        """
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        if not verify_password(old_password, user.password_hash):
            raise AuthError("旧密码错误", "WRONG_PASSWORD")
        async with transactional(self._db):
            user.password_hash = hash_password(new_password)

    async def change_phone(
        self,
        user_id: str,
        new_phone: str,
        code: str,
    ) -> User:
        """Change the phone number with a ``change_phone`` code sent to the new number.

        Returns the refreshed user.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_TAKEN`` or ``USER_NOT_FOUND``.
        """
        precheck_error = await self._precheck_sms_code(new_phone, code, "change_phone")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")
        existing = await repo.get_user_by_phone(new_phone, self._db)
        if existing and existing.id != user_id:
            raise AuthError("该手机号已被其他用户使用", "PHONE_TAKEN")
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        async with transactional(self._db):
            await self._consume_sms_code_or_raise(new_phone, code, "change_phone")
            user.phone = new_phone
            try:
                # Flush now so a concurrent claim of new_phone maps to PHONE_TAKEN.
                await self._db.flush()
            except IntegrityError as exc:
                _raise_auth_error_from_user_integrity(exc, phone_conflict="PHONE_TAKEN")
        await self._db.refresh(user)
        return user

    async def update_nickname(self, user_id: str, nickname: str) -> User:
        """Set a (whitespace-stripped) nickname and return the refreshed user."""
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        async with transactional(self._db):
            user.nickname = nickname.strip()
        await self._db.refresh(user)
        return user

    async def update_avatar_url(self, user_id: str, avatar_url: str) -> User:
        """Persist a new avatar URL and return the refreshed user."""
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        async with transactional(self._db):
            user.avatar_url = avatar_url
        await self._db.refresh(user)
        return user

    async def upload_avatar(
        self,
        user_id: str,
        file_content: bytes,
        content_type: str,
        *,
        old_avatar_url: str | None,
    ) -> User:
        """Validate, process and upload an avatar to COS, then persist its URL.

        The image becomes a square JPEG stored at ``avatars/<user_id>.jpg``.
        A previous avatar at a different COS key owned by us is deleted on a
        best-effort basis after success.

        Raises:
            BadRequestError: unsupported type, empty, oversized or undecodable file.
            ServiceUnavailableError: storage not configured or upload failed.
        """
        allowed_types = ["image/jpeg", "image/png", "image/webp"]
        if content_type not in allowed_types:
            raise BadRequestError(
                f"不支持的文件类型。仅支持: {', '.join(allowed_types)}"
            )
        if not file_content:
            raise BadRequestError("文件内容为空")
        if len(file_content) > 5 * 1024 * 1024:
            raise BadRequestError("文件大小超过5MB限制")
        storage_configured = bool(
            (settings.tencent_secret_id or "").strip()
            and (settings.tencent_secret_key or "").strip()
            and (settings.tencent_cos_bucket or "").strip()
        )
        # Check the storage port before the (CPU-bound) image processing.
        if not storage_configured or not self._object_storage:
            raise ServiceUnavailableError("头像存储服务未配置,请稍后再试")
        storage = self._object_storage
        jpeg_bytes = await _process_avatar_jpeg_async(file_content)
        cos_key = f"avatars/{user_id}.jpg"
        old_key = (
            extract_cos_object_key_if_owned(old_avatar_url) if old_avatar_url else None
        )
        try:
            avatar_url = await asyncio.to_thread(
                storage.upload, cos_key, jpeg_bytes, "image/jpeg"
            )
        except Exception as exc:
            logger.exception("COS 头像上传失败: {}", exc)
            raise ServiceUnavailableError("头像存储暂时不可用,请稍后再试") from exc
        try:
            user = await self.update_avatar_url(user_id, avatar_url)
        except Exception:
            # cos_key is fixed per user. When the previous avatar already lived
            # at that key, the stored URL (presumably old_avatar_url) still points
            # at it, so deleting the object would leave that URL dangling.
            if old_key != cos_key:
                try:
                    await asyncio.to_thread(storage.delete, cos_key)
                except Exception as cleanup_exc:
                    logger.warning(
                        "头像 DB 写入失败后清理 COS 对象失败: key={} err={}",
                        cos_key,
                        cleanup_exc,
                    )
            raise
        if old_key and old_key != cos_key:
            best_effort_delete_cos_object_for_url(old_avatar_url)
        return user
def _is_valid_image_header(header: bytes) -> bool:
if header.startswith(b"\xff\xd8\xff"):
return True
if header.startswith(b"\x89PNG\r\n\x1a\n"):
return True
if header.startswith(b"RIFF") and b"WEBP" in header[:12]:
return True
return False
async def _process_avatar_jpeg_async(file_content: bytes) -> bytes:
    """Run ``_process_avatar_jpeg`` in a worker thread.

    ``BadRequestError`` propagates unchanged. Any other ``Exception`` (e.g.
    Pillow failing to decode the bytes) is re-raised as
    ``BadRequestError("无效的图片文件")`` chained to the original.
    """
    try:
        jpeg = await asyncio.to_thread(_process_avatar_jpeg, file_content)
    except Exception as exc:
        if isinstance(exc, BadRequestError):
            raise
        raise BadRequestError("无效的图片文件") from exc
    return jpeg
def _process_avatar_jpeg(file_content: bytes) -> bytes:
    """Turn uploaded image bytes into a square JPEG avatar (at most 512x512).

    Steps:
    1. Check the magic bytes.
    2. Apply the EXIF orientation.
    3. Convert to RGB.
    4. Center-crop to a square.
    5. Downscale to 512x512 if larger.
    6. Encode as JPEG (quality 85, optimized).

    Raises:
        BadRequestError: when the header is not JPEG/PNG/WebP.
        Exception: whatever Pillow raises for undecodable data
            (``_process_avatar_jpeg_async`` maps these to BadRequestError).
    """
    from PIL import ImageOps

    image_bytes = io.BytesIO(file_content)
    header = image_bytes.read(16)
    image_bytes.seek(0)
    if not _is_valid_image_header(header):
        raise BadRequestError(f"无效的图片文件格式。文件头: {header[:12].hex()}")
    image = Image.open(image_bytes)
    # Phone cameras often record rotation in EXIF instead of rotating pixels;
    # apply it before cropping so the avatar isn't saved sideways.
    image = ImageOps.exif_transpose(image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    width, height = image.size
    size = min(width, height)
    left = (width - size) // 2
    top = (height - size) // 2
    image = image.crop((left, top, left + size, top + size))
    if size > 512:
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
    jpeg_buffer = io.BytesIO()
    image.save(jpeg_buffer, format="JPEG", quality=85, optimize=True)
    return jpeg_buffer.getvalue()