* feat(api): implement Google OAuth login and user management - Added Google OpenID Connect login functionality, allowing users to authenticate using their Google accounts. - Created new endpoints for Google login, including user registration and linking existing accounts. - Introduced Google token verification logic and error handling for authentication failures. - Updated environment configuration to include Google OAuth client IDs and verification settings. - Enhanced user model to support OpenID and linked Google accounts. This feature improves user experience by enabling seamless sign-in with Google, while maintaining security and integrity of user data. * fix(auth): wire staging Google token verifier * chore(deps): update expo to version 55.0.6 and adjust @expo/env dependency in pnpm-lock.yaml * chore(deps): update Babel dependencies to version 7.29.7 in package-lock.json * feat(auth): enhance phone login for China users - Updated phone login functionality to support only mainland China (+86) mobile numbers. - Added user prompts and descriptions for phone login, including confirmation and cancellation options. - Adjusted translations for both English and Chinese to reflect the new phone login requirements. - Updated Google OAuth client IDs in configuration files for production and staging environments. * chore(deps): add peer flag to use-sync-external-store in package-lock.json * chore(deps): add @emnapi/core and @emnapi/runtime to package-lock.json * fix(app-expo): align Android native dependencies * fix(app-expo): normalize lockfile for npm 10 * fix(config): update environment variable handling to use static access - Introduced a static mapping for public environment variables to ensure proper access during the release bundle. - Updated the `requirePublicEnv` and `optionalPublicEnv` functions to reference the new `PUBLIC_ENV` object instead of directly accessing `process.env`. 
- Added comments to clarify the necessity of static access for certain environment variables. * feat(app-expo): dark mode, FAQ i18n, eval ASR, and theme cleanup (#34) * feat(app-expo): dark mode, FAQ i18n, version CI, and theme cleanup Implement light/dark scene colors across chat, reading, and headers; remove default/brand theme picker and ThemeVariablesProvider. Localize FAQ in-app, fix dark-mode text visibility, and remove the unused /api/faqs endpoint. Align About/version with Expo config and inject APP_VERSION in CI builds. Also includes phone E164 auth/SMS updates, eval ASR page, and related API work. * revert: remove phone E.164 changes from dark-mode branch These auth/SMS internationalization updates were accidentally bundled into the dark-mode commit; restore 11-digit CN phone flow and drop related API, migration, and Expo UI work from this branch. * fix: address PR review issues for dark mode and eval ASR Use light foreground colors for sepia reading in dark mode, fix chat send button contrast, stream-limit eval ASR uploads, restore LiveTester phone validation, and remove unused AudioSegmenter code. * fix(app-expo): improve chat send button contrast in light and dark mode Add dedicated send button colors (accent fill in dark, primary fill in light), use RNText to avoid NativeWind overrides, and restore dark labels in light mode for readable composer actions. --------- Co-authored-by: Kevin <kevin@brighteng.org> --------- Co-authored-by: penghanyuan <penghanyuan@gmail.com> Co-authored-by: Kevin <kevin@brighteng.org>
857 lines
31 KiB
Python
857 lines
31 KiB
Python
import asyncio
|
|
import io
|
|
import random
|
|
import secrets
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from PIL import Image
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.cos_url_keys import (
|
|
best_effort_delete_cos_object_for_url,
|
|
extract_cos_object_key_if_owned,
|
|
)
|
|
from app.core.db import transactional, transactional_nested, utc_now
|
|
from app.core.errors import (
|
|
BadRequestError,
|
|
ProviderError,
|
|
RateLimitedError,
|
|
ServiceUnavailableError,
|
|
)
|
|
from app.core.logging import get_logger
|
|
from app.core.security import (
|
|
create_access_token,
|
|
get_token_expires_at,
|
|
hash_password,
|
|
verify_password,
|
|
)
|
|
from app.core.security import (
|
|
create_refresh_token as generate_refresh_token_str,
|
|
)
|
|
from app.features.auth import repo
|
|
from app.features.auth.google import GoogleIdentity, verify_google_id_token
|
|
from app.features.auth.integrity import (
|
|
is_user_phone_unique_violation,
|
|
user_integrity_auth_code,
|
|
)
|
|
from app.features.auth.models import RefreshToken, SmsVerificationCode
|
|
from app.features.auth.service_errors import AuthError
|
|
from app.features.user.models import User
|
|
from app.ports.sms import SmsSender
|
|
from app.ports.storage import ObjectStorage
|
|
|
|
# Module-level logger for the auth service (used e.g. for refresh-token reuse warnings).
logger = get_logger(__name__)
|
|
|
|
# ── SMS verification-code policy ─────────────────────────────
CODE_LENGTH = 6  # digits in each generated verification code
CODE_EXPIRE_MINUTES = 5  # lifetime of a freshly issued code
RATE_LIMIT_SECONDS = 60  # minimum interval between two sends to the same phone

# Generic user-facing message ("code does not exist or has already been used"),
# returned when no usable code record can be found.
_SMS_CONSUME_FALLBACK = "验证码不存在或已使用"
|
|
|
|
|
|
def _sms_is_configured() -> bool:
    """Report whether every Tencent SMS setting is present and non-blank.

    Settings are read lazily in order, so evaluation stops at the first
    missing/blank value (``None`` counts as blank).
    """
    required_settings = (
        "tencent_secret_id",
        "tencent_secret_key",
        "tencent_sms_sdk_app_id",
        "tencent_sms_sign_name",
        "tencent_sms_template_id",
    )
    return all(
        (getattr(settings, setting_name) or "").strip()
        for setting_name in required_settings
    )
|
|
|
|
|
|
_VALID_LANGUAGES = {"zh", "en"}
|
|
|
|
|
|
def _normalize_language(lang: str | None) -> str:
|
|
"""Normalize device language token; default to zh on missing/unknown."""
|
|
if not lang:
|
|
return "zh"
|
|
s = str(lang).strip().lower()
|
|
return s if s in _VALID_LANGUAGES else "zh"
|
|
|
|
|
|
def _as_utc(dt: datetime) -> datetime:
    """Return ``dt`` as a timezone-aware UTC datetime.

    Naive values (SQLite may return these) are assumed to already be UTC and
    are only tagged; aware values are converted to UTC.
    """
    return (
        dt.replace(tzinfo=timezone.utc)
        if dt.tzinfo is None
        else dt.astimezone(timezone.utc)
    )
|
|
|
|
|
|
def _raise_auth_error_from_user_integrity(
    exc: IntegrityError,
    *,
    phone_conflict: str,
) -> None:
    """Translate a user-table ``IntegrityError`` into an ``AuthError``.

    ``user_integrity_auth_code`` classifies the violation, using
    ``phone_conflict`` as the code for a phone clash. A recognised code is
    raised as ``AuthError(message, code)`` chained to ``exc``; anything else
    re-raises ``exc`` unchanged. This function never returns normally.
    """
    known_errors = (
        ("PHONE_EXISTS", "该手机号已被注册"),  # phone already registered
        ("EMAIL_EXISTS", "该邮箱已被注册"),  # email already registered
        ("PHONE_TAKEN", "该手机号已被其他用户使用"),  # phone used by another user
    )
    classified = user_integrity_auth_code(exc, phone_conflict=phone_conflict)
    for error_code, message in known_errors:
        if classified == error_code:
            raise AuthError(message, error_code) from exc
    raise exc
|
|
|
|
|
|
async def _create_user_with_integrity_check(
    db: AsyncSession,
    user: User,
    *,
    phone_conflict: str,
) -> None:
    """Add ``user`` through the repo and flush so unique constraints fire now.

    A flush-time ``IntegrityError`` is handed to
    ``_raise_auth_error_from_user_integrity`` (``phone_conflict`` selects the
    error code for a phone clash); other exceptions propagate untouched.
    """
    await repo.create_user(user, db)
    try:
        await db.flush()
    except IntegrityError as integrity_exc:
        _raise_auth_error_from_user_integrity(
            integrity_exc, phone_conflict=phone_conflict
        )
|
|
|
|
|
|
def _public_tokens(issued: dict) -> dict:
    """Keep only the client-facing token fields of ``issued``.

    Internal keys such as ``refresh_token_id`` are dropped; a missing
    ``access_token``/``refresh_token`` raises ``KeyError``.
    """
    client_fields = ("access_token", "refresh_token")
    return {field: issued[field] for field in client_fields}
|
|
|
|
|
|
def _google_openid(subject: str) -> str:
    """Build the namespaced ``openid`` value stored for a Google account."""
    return "google:{}".format(subject)
|
|
|
|
|
|
def _google_internal_phone(subject: str) -> str:
    """Build the placeholder ``phone`` value for an auto-registered Google user.

    Used as the ``phone`` column value when ``login_with_google`` creates a
    new account, keyed by the Google subject.
    """
    return "google:{}".format(subject)
|
|
|
|
|
|
def _nickname_from_google_identity(identity: GoogleIdentity) -> str:
    """Derive a display nickname (at most 50 characters) from a Google identity.

    Prefers the profile ``name``. If that is blank, falls back to the local
    part of ``email``, and finally to ``"Google User"``.

    ``name``/``email`` are treated as empty when ``None`` so an incomplete
    token payload cannot crash sign-up. (NOTE(review): whether
    ``GoogleIdentity`` can carry ``None`` here is not visible in this file.)
    """
    name = (identity.name or "").strip()
    if name:
        return name[:50]
    local_part = (identity.email or "").split("@", 1)[0].strip()
    return (local_part or "Google User")[:50]
|
|
|
|
|
|
class AuthService:
    """Authentication and account-management use cases.

    Covers SMS verification codes, phone/password, SMS and Google sign-in,
    refresh-token rotation with reuse detection, password/phone changes and
    profile (nickname/avatar) updates. Persistence goes through ``repo`` on the
    injected ``AsyncSession``.
    """

    def __init__(
        self,
        db: AsyncSession,
        sms: SmsSender,
        *,
        object_storage: ObjectStorage | None = None,
    ):
        """
        Args:
            db: Session used for every read and write.
            sms: Sender used to deliver verification codes.
            object_storage: Avatar storage backend. ``upload_avatar`` raises
                ``ServiceUnavailableError`` when it is not provided.
        """
        self._db = db
        self._sms = sms
        self._object_storage = object_storage

    # ── private helpers ──────────────────────────────────────

    def _generate_code(self) -> str:
        """Return a random numeric code of ``CODE_LENGTH`` digits.

        Uses ``secrets`` (a CSPRNG) instead of ``random``. The code is an
        authentication secret, and ``random``'s Mersenne Twister output is
        predictable.
        """
        return "".join(str(secrets.randbelow(10)) for _ in range(CODE_LENGTH))

    async def _check_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> tuple[SmsVerificationCode | None, str]:
        """Validate an SMS code without consuming it. Returns (record, message).

        UX pre-check only; authoritative validation is
        ``try_consume_verification_code`` inside a transaction.

        Returns ``(record, "验证成功")`` when the latest unused code for
        ``phone``/``purpose`` matches and has not expired; otherwise
        ``(None, <user-facing reason>)``. An expired record is flagged
        ``is_expired`` in its own transaction.
        """
        record = await repo.get_latest_unused_code(phone, purpose, self._db)
        if not record:
            return None, _SMS_CONSUME_FALLBACK
        # SQLite may return a naive datetime; comparing it with the aware
        # utc_now() would raise TypeError, so normalise first.
        if utc_now() > _as_utc(record.expires_at):
            # NOTE(review): this can run inside an outer ``transactional``
            # (via _sms_invalid_message_after_consume_failure); confirm that
            # ``transactional`` tolerates nesting.
            async with transactional(self._db):
                record.is_expired = True
            return None, "验证码已过期"
        if record.code != code:
            return None, "验证码错误"
        return record, "验证成功"

    async def _precheck_sms_code(
        self, phone: str, code: str, purpose: str
    ) -> str | None:
        """UX pre-check: fast-fail without consuming. None means likely valid."""
        record, message = await self._check_sms_code(phone, code, purpose)
        if record is None:
            return message
        return None

    async def _precheck_sms_code_for_purposes(
        self, phone: str, code: str, purposes: tuple[str, ...]
    ) -> str | None:
        """Login flow: try each purpose until one pre-check passes.

        Returns None on the first passing purpose. Otherwise returns the
        message from the last purpose tried (or the generic fallback when
        ``purposes`` is empty).
        """
        last_message = _SMS_CONSUME_FALLBACK
        for purpose in purposes:
            record, message = await self._check_sms_code(phone, code, purpose)
            if record is not None:
                return None
            last_message = message
        return last_message

    async def _sms_invalid_message_after_consume_failure(
        self,
        phone: str,
        code: str,
        purposes: tuple[str, ...],
        *,
        fallback: str = _SMS_CONSUME_FALLBACK,
    ) -> str:
        """Re-read code state after atomic consume failed (race/expiry/concurrency).

        Returns the first specific failure reason found across ``purposes``,
        or ``fallback`` if every pre-check now looks valid.
        """
        for purpose in purposes:
            record, message = await self._check_sms_code(phone, code, purpose)
            if record is None:
                return message
        return fallback

    async def _consume_sms_code_or_raise(
        self,
        phone: str,
        code: str,
        purpose: str,
        *,
        purposes: tuple[str, ...] | None = None,
    ) -> SmsVerificationCode:
        """Atomically consume an SMS code; call inside ``transactional()``.

        Tries ``purposes`` (or just ``purpose``) in order and returns the first
        consumed record.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` when no purpose could be consumed.
        """
        purposes_to_try = purposes or (purpose,)
        for p in purposes_to_try:
            consumed = await repo.try_consume_verification_code(
                phone, code, p, self._db
            )
            if consumed is not None:
                return consumed
        message = await self._sms_invalid_message_after_consume_failure(
            phone, code, purposes_to_try
        )
        raise AuthError(message, "INVALID_SMS_CODE")

    async def send_sms_code(
        self,
        phone: str,
        purpose: str,
        ip_address: str | None = None,
    ) -> tuple[bool, str, int]:
        """Send an SMS verification code.

        Returns:
            ``(True, message, expires_in_seconds)`` on success.

        Raises:
            ServiceUnavailableError: SMS settings are not configured.
            AuthError: ``PHONE_EXISTS`` for ``register`` when the phone is
                already taken, or ``USER_NOT_FOUND`` for ``reset_password``
                when no user has the phone.
            RateLimitedError: the repo reports a code for this phone created
                less than ``RATE_LIMIT_SECONDS`` ago.
            ProviderError: the SMS sender reported failure. The stored code is
                then marked expired.
        """
        import math  # local: only needed for the rate-limit countdown

        if not _sms_is_configured():
            raise ServiceUnavailableError("短信服务未配置,请稍后再试")
        if purpose == "register":
            if await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号已被注册", "PHONE_EXISTS")
        if purpose == "reset_password":
            if not await repo.get_user_by_phone(phone, self._db):
                raise AuthError("该手机号未注册", "USER_NOT_FOUND")

        recent = await repo.get_recent_code_for_rate_limit(phone, self._db)
        if recent:
            # _as_utc: SQLite may return a naive created_at (see _as_utc).
            elapsed = (utc_now() - _as_utc(recent.created_at)).total_seconds()
            if elapsed < RATE_LIMIT_SECONDS:
                # Round up so the client is never told to wait "0 seconds"
                # while the limit is still active.
                remaining = math.ceil(RATE_LIMIT_SECONDS - elapsed)
                raise RateLimitedError(f"发送过于频繁,请{remaining}秒后再试")

        code = self._generate_code()
        expires_at = utc_now() + timedelta(minutes=CODE_EXPIRE_MINUTES)
        record = SmsVerificationCode(
            id=str(uuid.uuid4()),
            phone=phone,
            code=code,
            purpose=purpose,
            expires_at=expires_at,
            ip_address=ip_address,
        )
        # Persist before sending so the code exists when the user submits it.
        async with transactional(self._db):
            await repo.create_verification_code(record, self._db)

        if not self._sms.send_verification_code(phone, code):
            # Invalidate the undelivered code so it cannot be used.
            async with transactional(self._db):
                await repo.mark_verification_code_expired(record.id, self._db)
            raise ProviderError("短信发送失败,请稍后重试")
        return True, "验证码已发送", CODE_EXPIRE_MINUTES * 60

    async def _issue_tokens(self, user_id: str, device_info: str = "") -> dict:
        """Create an access + refresh token pair and add the refresh row to the session.

        Nothing is committed here; callers run this inside ``transactional``.
        The result includes the internal ``refresh_token_id``. Pass it through
        ``_public_tokens`` before returning it to clients.
        """
        refresh_str = generate_refresh_token_str()
        token = RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user_id,
            token=refresh_str,
            expires_at=get_token_expires_at(),
            created_at=datetime.now(timezone.utc),
            device_info=device_info or None,
        )
        await repo.create_refresh_token(token, self._db)
        access = create_access_token(data={"sub": user_id})
        return {
            "access_token": access,
            "refresh_token": refresh_str,
            "refresh_token_id": token.id,
        }

    async def _try_idempotent_refresh_within_grace(
        self,
        token_record: RefreshToken,
    ) -> dict | None:
        """Grace-window retry: return new access + existing replacement refresh.

        Applies only when ``token_record`` was rotated less than
        ``refresh_token_reuse_grace_seconds`` ago. Its replacement must also be
        unrevoked and unexpired, and belong to an existing user. Otherwise
        returns None.
        """
        grace = settings.refresh_token_reuse_grace_seconds
        if grace <= 0:
            return None
        if not token_record.replaced_by_token_id or token_record.rotated_at is None:
            return None
        now = utc_now()
        rotated_at = _as_utc(token_record.rotated_at)
        if rotated_at + timedelta(seconds=grace) <= now:
            return None
        replacement = await repo.get_refresh_token_by_id(
            token_record.replaced_by_token_id, self._db
        )
        if replacement is None or replacement.is_revoked:
            return None
        if _as_utc(replacement.expires_at) < now:
            return None
        user = await repo.get_user_by_id(replacement.user_id, self._db)
        if user is None:
            return None
        access = create_access_token(data={"sub": user.id})
        return {
            "access_token": access,
            "refresh_token": replacement.token,
        }

    # ── public API ───────────────────────────────────────────

    async def register(
        self,
        phone: str,
        password: str,
        nickname: str,
        email: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Register a new user. Returns {user, access_token, refresh_token}.

        Raises:
            AuthError: ``PHONE_EXISTS`` / ``EMAIL_EXISTS``. These are checked
                up front and again through the unique constraints at flush
                time.
        """
        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")

        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")

        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
            language_preference=_normalize_language(language),
        )
        async with transactional(self._db):
            await _create_user_with_integrity_check(
                self._db, user, phone_conflict="PHONE_EXISTS"
            )
            tokens = await self._issue_tokens(user_id)
        await self._db.refresh(user)
        return {"user": user, **_public_tokens(tokens)}

    async def login(
        self,
        phone: str,
        password: str,
        device_info: str = "",
    ) -> dict:
        """Login with phone+password. Returns {user, access_token, refresh_token}.

        Raises:
            AuthError: ``INVALID_CREDENTIALS`` when the phone is unknown or the
                password does not match.
        """
        user = await repo.get_user_by_phone(phone, self._db)
        if not user or not verify_password(password, user.password_hash):
            raise AuthError("手机号或密码错误", "INVALID_CREDENTIALS")

        async with transactional(self._db):
            tokens = await self._issue_tokens(user.id, device_info)
        return {"user": user, **_public_tokens(tokens)}

    async def login_with_google(
        self,
        id_token: str,
        device_info: str = "",
        language: str | None = None,
    ) -> dict:
        """Google OpenID Connect login. Auto-registers a user on first sign-in.

        Resolution order:
          1. a user already linked to this Google subject (``openid``);
          2. a user with the same email, which gets linked;
          3. otherwise a new user with a placeholder phone and random password.

        Returns {user, is_new_user, access_token, refresh_token}.

        Raises:
            AuthError: ``EMAIL_EXISTS`` when the matching email account is
                already linked to a different third-party identity.
        """
        # The verifier is synchronous, so run it off the event loop.
        identity = await asyncio.to_thread(verify_google_id_token, id_token)
        openid = _google_openid(identity.subject)

        user = await repo.get_user_by_openid(openid, self._db)
        is_new_user = False

        async with transactional(self._db):
            if user is None:
                # NOTE(review): linking relies on an email match alone; this
                # assumes verify_google_id_token only yields verified emails —
                # confirm.
                user = await repo.get_user_by_email(identity.email, self._db)
                if user is not None:
                    if user.openid and user.openid != openid:
                        raise AuthError("该邮箱已绑定其他第三方账号", "EMAIL_EXISTS")
                    user.openid = openid
                    if not user.email:
                        user.email = identity.email
                    if identity.picture and not user.avatar_url:
                        user.avatar_url = identity.picture
                else:
                    is_new_user = True
                    user = User(
                        id=str(uuid.uuid4()),
                        phone=_google_internal_phone(identity.subject),
                        password_hash=hash_password(secrets.token_urlsafe(32)),
                        email=identity.email,
                        openid=openid,
                        nickname=_nickname_from_google_identity(identity),
                        avatar_url=identity.picture,
                        subscription_type="free",
                        created_at=datetime.now(timezone.utc),
                        language_preference=_normalize_language(language),
                    )
                    await _create_user_with_integrity_check(
                        self._db, user, phone_conflict="PHONE_EXISTS"
                    )

            tokens = await self._issue_tokens(user.id, device_info)

        if is_new_user:
            await self._db.refresh(user)
        return {"user": user, "is_new_user": is_new_user, **_public_tokens(tokens)}

    async def _revoke_all_active_tokens_in_session(self, user_id: str) -> int:
        """Revoke all active refresh tokens on the current session (no commit).

        Returns the number of tokens revoked.
        """
        tokens = await repo.get_active_tokens_for_user(user_id, self._db)
        for token in tokens:
            token.is_revoked = True
        return len(tokens)

    async def refresh_tokens(
        self,
        refresh_token: str,
        device_info: str = "",
    ) -> dict:
        """Rotate a refresh token and issue a new access token pair.

        Outcomes:
          * The token is consumed: issue a new pair, link it to the old row,
            and return it.
          * The token is already revoked but was rotated within the grace
            window: return the existing replacement (retry-safe).
          * The token is revoked outside the grace window: treat it as reuse.
            Every active token of the user is revoked and committed, then
            ``REFRESH_TOKEN_REUSE`` is raised.

        Raises:
            AuthError: ``USER_NOT_FOUND``, ``INVALID_TOKEN``,
                ``TOKEN_EXPIRED`` or ``REFRESH_TOKEN_REUSE``.
        """
        reuse_detected = False
        async with transactional(self._db):
            consumed = await repo.try_consume_refresh_token(refresh_token, self._db)
            if consumed is not None:
                user = await repo.get_user_by_id(consumed.user_id, self._db)
                if not user:
                    raise AuthError("用户不存在", "USER_NOT_FOUND")
                issued = await self._issue_tokens(
                    consumed.user_id,
                    device_info or (consumed.device_info or ""),
                )
                await repo.link_refresh_rotation(
                    consumed.id,
                    issued["refresh_token_id"],
                    utc_now(),
                    self._db,
                )
                return _public_tokens(issued)

            token_record = await repo.get_refresh_token_by_token(
                refresh_token, self._db
            )
            if not token_record:
                raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
            if token_record.is_revoked:
                idempotent = await self._try_idempotent_refresh_within_grace(
                    token_record
                )
                if idempotent is None and not token_record.replaced_by_token_id:
                    # Concurrent refresh may observe revoke before lineage commits.
                    token_record = (
                        await repo.get_refresh_token_by_token(refresh_token, self._db)
                        or token_record
                    )
                    idempotent = await self._try_idempotent_refresh_within_grace(
                        token_record
                    )
                if idempotent is not None:
                    return idempotent
                grace = settings.refresh_token_reuse_grace_seconds
                rotated_at = token_record.rotated_at
                still_in_grace = (
                    grace > 0
                    and rotated_at is not None
                    and _as_utc(rotated_at) + timedelta(seconds=grace) > utc_now()
                )
                if still_in_grace:
                    raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
                if (
                    grace > 0
                    and rotated_at is None
                    and not token_record.replaced_by_token_id
                ):
                    # Revoke visible but lineage not committed yet (concurrent rotation).
                    raise AuthError("无效的刷新令牌", "INVALID_TOKEN")
                logger.bind(user_id=token_record.user_id).warning(
                    "Refresh token reuse detected (grace expired or no lineage)"
                )
                await self._revoke_all_active_tokens_in_session(token_record.user_id)
                # Raise only after the transaction commits, so the revocation
                # is persisted rather than rolled back.
                reuse_detected = True
            elif _as_utc(token_record.expires_at) < utc_now():
                raise AuthError("刷新令牌已过期", "TOKEN_EXPIRED")
            else:
                raise AuthError("无效的刷新令牌", "INVALID_TOKEN")

        if reuse_detected:
            raise AuthError("刷新令牌已失效,请重新登录", "REFRESH_TOKEN_REUSE")

    async def logout(self, refresh_token: str, user_id: str) -> None:
        """Revoke a refresh token owned by the given user.

        Does nothing when the token is unknown or belongs to another user.
        """
        token_record = await repo.get_refresh_token_by_token(refresh_token, self._db)
        if token_record and token_record.user_id == user_id:
            async with transactional(self._db):
                token_record.is_revoked = True

    async def logout_all(self, user_id: str) -> int:
        """Revoke all refresh tokens for user. Returns count revoked."""
        async with transactional(self._db):
            return await self._revoke_all_active_tokens_in_session(user_id)

    async def login_with_sms(
        self,
        phone: str,
        code: str,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """SMS login (auto-register if new). Returns {user, access_token, refresh_token, is_new_user}.

        A code issued for either ``login`` or ``register`` is accepted.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` when the code fails the pre-check
                or cannot be consumed.
        """
        precheck_error = await self._precheck_sms_code_for_purposes(
            phone, code, ("login", "register")
        )
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        return await self._sms_login_after_code_verified(
            phone,
            code=code,
            device_info=device_info,
            nickname=nickname,
            language=language,
        )

    async def _sms_login_after_code_verified(
        self,
        phone: str,
        *,
        code: str | None = None,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Finish an SMS login after the code pre-check has passed.

        A single transaction atomically consumes the code, creates the user if
        needed, and issues tokens.

        ``language`` is only written in the new-user branch; an existing
        user's preference is never overwritten. The mock route passes
        ``code=None`` to skip code consumption.

        If a concurrent request creates the same phone first (unique
        violation), the existing user is logged in instead.
        """
        user = await repo.get_user_by_phone(phone, self._db)
        is_new_user = user is None

        if is_new_user:
            user_id = str(uuid.uuid4())
            user = User(
                id=user_id,
                phone=phone,
                password_hash=hash_password(secrets.token_urlsafe(32)),
                nickname=(nickname or "").strip(),
                subscription_type="free",
                created_at=datetime.now(timezone.utc),
                language_preference=_normalize_language(language),
            )

        async with transactional(self._db):
            if code is not None:
                await self._consume_sms_code_or_raise(
                    phone, code, "login", purposes=("login", "register")
                )
            if is_new_user:
                try:
                    # Savepoint: a failed insert must not roll back the code
                    # consumption above.
                    async with transactional_nested(self._db):
                        await repo.create_user(user, self._db)
                        await self._db.flush()
                except IntegrityError as exc:
                    if is_user_phone_unique_violation(exc):
                        existing = await repo.get_user_by_phone(phone, self._db)
                        if existing is None:
                            raise
                        user = existing
                        is_new_user = False
                    else:
                        _raise_auth_error_from_user_integrity(
                            exc, phone_conflict="PHONE_EXISTS"
                        )
            tokens = await self._issue_tokens(user.id, device_info)
        if is_new_user:
            await self._db.refresh(user)

        return {"user": user, "is_new_user": is_new_user, **_public_tokens(tokens)}

    async def mock_sms_login(
        self,
        phone: str,
        device_info: str = "",
        nickname: str | None = None,
        language: str | None = None,
    ) -> dict:
        """Log in or auto-register without SMS verification.

        Only called by the mock route when configuration allows it.
        """
        return await self._sms_login_after_code_verified(
            phone,
            device_info=device_info,
            nickname=nickname,
            language=language,
        )

    async def register_with_sms(
        self,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None = None,
        device_info: str = "",
        language: str | None = None,
    ) -> dict:
        """SMS register. Returns {user, access_token, refresh_token}.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_EXISTS`` or
                ``EMAIL_EXISTS``.
        """
        precheck_error = await self._precheck_sms_code(phone, code, "register")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        if await repo.get_user_by_phone(phone, self._db):
            raise AuthError("该手机号已被注册", "PHONE_EXISTS")

        if email and await repo.get_user_by_email(email, self._db):
            raise AuthError("该邮箱已被注册", "EMAIL_EXISTS")

        return await self._register_after_sms_verified(
            phone=phone,
            code=code,
            password=password,
            nickname=nickname,
            email=email,
            device_info=device_info,
            language=language,
        )

    async def _register_after_sms_verified(
        self,
        *,
        phone: str,
        code: str,
        password: str,
        nickname: str,
        email: str | None,
        device_info: str,
        language: str | None,
    ) -> dict:
        """Consume the ``register`` code, create the user, and issue tokens in one transaction."""
        user_id = str(uuid.uuid4())
        user = User(
            id=user_id,
            phone=phone,
            password_hash=hash_password(password),
            email=email,
            nickname=nickname,
            subscription_type="free",
            created_at=datetime.now(timezone.utc),
            language_preference=_normalize_language(language),
        )
        async with transactional(self._db):
            await self._consume_sms_code_or_raise(phone, code, "register")
            await _create_user_with_integrity_check(
                self._db, user, phone_conflict="PHONE_EXISTS"
            )
            tokens = await self._issue_tokens(user_id, device_info)
        await self._db.refresh(user)
        return {"user": user, **_public_tokens(tokens)}

    async def reset_password(
        self,
        phone: str,
        code: str,
        new_password: str,
    ) -> None:
        """Reset password via SMS code.

        Raises:
            AuthError: ``INVALID_SMS_CODE`` or ``USER_NOT_FOUND``.
        """
        precheck_error = await self._precheck_sms_code(phone, code, "reset_password")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        user = await repo.get_user_by_phone(phone, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        async with transactional(self._db):
            await self._consume_sms_code_or_raise(phone, code, "reset_password")
            user.password_hash = hash_password(new_password)

    async def change_password(
        self,
        user_id: str,
        old_password: str,
        new_password: str,
    ) -> None:
        """Change password (requires old password).

        Raises:
            AuthError: ``USER_NOT_FOUND`` or ``WRONG_PASSWORD``.
        """
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        if not verify_password(old_password, user.password_hash):
            raise AuthError("旧密码错误", "WRONG_PASSWORD")

        async with transactional(self._db):
            user.password_hash = hash_password(new_password)

    async def change_phone(
        self,
        user_id: str,
        new_phone: str,
        code: str,
    ) -> User:
        """Change phone number via SMS code. Returns updated user.

        Raises:
            AuthError: ``INVALID_SMS_CODE``, ``PHONE_TAKEN`` (also raised for a
                concurrent unique violation at flush time) or
                ``USER_NOT_FOUND``.
        """
        precheck_error = await self._precheck_sms_code(new_phone, code, "change_phone")
        if precheck_error is not None:
            raise AuthError(precheck_error, "INVALID_SMS_CODE")

        existing = await repo.get_user_by_phone(new_phone, self._db)
        if existing and existing.id != user_id:
            raise AuthError("该手机号已被其他用户使用", "PHONE_TAKEN")

        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        async with transactional(self._db):
            await self._consume_sms_code_or_raise(new_phone, code, "change_phone")
            user.phone = new_phone
            try:
                await self._db.flush()
            except IntegrityError as exc:
                _raise_auth_error_from_user_integrity(exc, phone_conflict="PHONE_TAKEN")
        await self._db.refresh(user)
        return user

    async def update_nickname(self, user_id: str, nickname: str) -> User:
        """Update user nickname (whitespace-stripped). Returns the refreshed user."""
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")

        async with transactional(self._db):
            user.nickname = nickname.strip()
        await self._db.refresh(user)
        return user

    async def update_avatar_url(self, user_id: str, avatar_url: str) -> User:
        """Update user avatar URL. Returns the refreshed user."""
        user = await repo.get_user_by_id(user_id, self._db)
        if not user:
            raise AuthError("用户不存在", "USER_NOT_FOUND")
        async with transactional(self._db):
            user.avatar_url = avatar_url
        await self._db.refresh(user)
        return user

    async def upload_avatar(
        self,
        user_id: str,
        file_content: bytes,
        content_type: str,
        *,
        old_avatar_url: str | None,
    ) -> User:
        """Validate, process, upload avatar to COS, and persist URL.

        The image is normalised to JPEG and stored at ``avatars/{user_id}.jpg``.
        After the DB update succeeds, a previous avatar stored under a
        different key is deleted on a best-effort basis.

        Raises:
            BadRequestError: unsupported type, empty or >5MB file, or an
                invalid image.
            ServiceUnavailableError: storage is not configured, or the upload
                failed.
        """
        allowed_types = ["image/jpeg", "image/png", "image/webp"]
        if content_type not in allowed_types:
            raise BadRequestError(
                f"不支持的文件类型。仅支持: {', '.join(allowed_types)}"
            )
        if not file_content:
            raise BadRequestError("文件内容为空")
        if len(file_content) > 5 * 1024 * 1024:
            raise BadRequestError("文件大小超过5MB限制")
        if not (
            (settings.tencent_secret_id or "").strip()
            and (settings.tencent_secret_key or "").strip()
            and (settings.tencent_cos_bucket or "").strip()
        ):
            raise ServiceUnavailableError("头像存储服务未配置,请稍后再试")

        jpeg_bytes = await _process_avatar_jpeg_async(file_content)
        cos_key = f"avatars/{user_id}.jpg"
        old_key = (
            extract_cos_object_key_if_owned(old_avatar_url) if old_avatar_url else None
        )

        if not self._object_storage:
            raise ServiceUnavailableError("头像存储服务未配置,请稍后再试")

        try:
            avatar_url = await asyncio.to_thread(
                self._object_storage.upload, cos_key, jpeg_bytes, "image/jpeg"
            )
        except Exception as exc:
            logger.exception("COS 头像上传失败: {}", exc)
            raise ServiceUnavailableError("头像存储暂时不可用,请稍后再试") from exc

        try:
            user = await self.update_avatar_url(user_id, avatar_url)
        except Exception:
            # Roll back the upload only if the object is not the user's current
            # avatar. When old_key == cos_key the upload overwrote the object
            # that the stored avatar URL still points at, and deleting it
            # would break the existing avatar.
            if old_key != cos_key:
                try:
                    await asyncio.to_thread(self._object_storage.delete, cos_key)
                except Exception as cleanup_exc:
                    logger.warning(
                        "头像 DB 写入失败后清理 COS 对象失败: key={} err={}",
                        cos_key,
                        cleanup_exc,
                    )
            raise

        if old_key and old_key != cos_key:
            # Synchronous helper; run it off the event loop since it
            # presumably performs a storage call.
            await asyncio.to_thread(best_effort_delete_cos_object_for_url, old_avatar_url)
        return user
|
|
|
|
|
|
def _is_valid_image_header(header: bytes) -> bool:
    """Return True if ``header`` begins with a JPEG, PNG or WebP signature.

    WebP files are RIFF containers: ``RIFF`` + 4-byte size + form type
    ``WEBP`` at offset 8. The form type is checked at that exact offset. This
    rejects other RIFF files (e.g. WAV/AVI) whose size field happens to
    contain the bytes ``WEBP``.
    """
    if header.startswith(b"\xff\xd8\xff"):  # JPEG SOI marker
        return True
    if header.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG signature
        return True
    return header.startswith(b"RIFF") and header[8:12] == b"WEBP"
|
|
|
|
|
|
async def _process_avatar_jpeg_async(file_content: bytes) -> bytes:
    """Run ``_process_avatar_jpeg`` in a worker thread.

    A ``BadRequestError`` raised by the processor propagates unchanged. Any
    other failure (e.g. data the image decoder cannot read) is reported as
    ``BadRequestError("无效的图片文件")``, chained to the original error.
    """
    try:
        jpeg = await asyncio.to_thread(_process_avatar_jpeg, file_content)
    except BadRequestError:
        raise
    except Exception as decode_error:
        raise BadRequestError("无效的图片文件") from decode_error
    return jpeg
|
|
|
|
|
|
def _process_avatar_jpeg(file_content: bytes) -> bytes:
    """Normalise an uploaded avatar into a square JPEG no larger than 512x512.

    Steps:
      1. Check the magic bytes.
      2. Apply any EXIF orientation.
      3. Flatten transparency onto white.
      4. Centre-crop to a square and downscale to 512px when larger.
      5. Re-encode as JPEG (quality 85).

    Raises:
        BadRequestError: the header is not a JPEG/PNG/WebP signature.
        Exception: whatever Pillow raises for undecodable data (the async
            wrapper maps such errors to ``BadRequestError``).
    """
    from PIL import ImageOps  # local: only this helper needs it

    image_bytes = io.BytesIO(file_content)
    header = image_bytes.read(16)
    image_bytes.seek(0)
    if not _is_valid_image_header(header):
        raise BadRequestError(f"无效的图片文件格式。文件头: {header[:12].hex()}")

    with Image.open(image_bytes) as source:
        # Phone cameras often store rotation in EXIF rather than in the pixels;
        # apply it so avatars are not displayed sideways.
        image = ImageOps.exif_transpose(source)

        has_alpha = image.mode in ("RGBA", "LA", "PA") or (
            image.mode == "P" and "transparency" in image.info
        )
        if has_alpha:
            # convert("RGB") would drop alpha and keep whatever colour the
            # transparent pixels hold (often black); composite onto white.
            rgba = image.convert("RGBA")
            flattened = Image.new("RGB", rgba.size, (255, 255, 255))
            flattened.paste(rgba, mask=rgba.getchannel("A"))
            image = flattened
        elif image.mode != "RGB":
            image = image.convert("RGB")

        # Centre-crop to the largest square, then cap the side at 512px.
        width, height = image.size
        size = min(width, height)
        left = (width - size) // 2
        top = (height - size) // 2
        image = image.crop((left, top, left + size, top + size))
        if size > 512:
            image = image.resize((512, 512), Image.Resampling.LANCZOS)

        jpeg_buffer = io.BytesIO()
        image.save(jpeg_buffer, format="JPEG", quality=85, optimize=True)
    return jpeg_buffer.getvalue()
|