配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM)
372 lines
12 KiB
Python
372 lines
12 KiB
Python
"""SMS 验证码原子消耗与业务事务。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from app.features.auth import repo
|
|
from app.features.auth.models import SmsVerificationCode
|
|
from app.features.auth.service import AuthError, AuthService
|
|
|
|
|
|
def _sms_record(
    *,
    phone: str = "13800138000",
    code: str = "123456",
    purpose: str = "login",
) -> SmsVerificationCode:
    """Build an unused, unexpired SMS verification-code row for tests.

    Args:
        phone: Phone number the code belongs to.
        code: The verification code value.
        purpose: Flow the code is issued for. Defaults to ``"login"``; pass
            e.g. ``"register"`` or ``"change_phone"`` instead of mutating
            ``purpose`` after construction.

    Returns:
        An ``SmsVerificationCode`` with id ``"code-1"``, ``is_used`` and
        ``is_expired`` both False, created now (UTC-aware) and expiring five
        minutes later. This helper does not add it to any session.
    """
    now = datetime.now(timezone.utc)
    return SmsVerificationCode(
        id="code-1",
        phone=phone,
        code=code,
        purpose=purpose,
        is_used=False,
        is_expired=False,
        expires_at=now + timedelta(minutes=5),
        created_at=now,
    )
|
|
|
|
|
|
def _phone_integrity_error() -> IntegrityError:
    """Return an ``IntegrityError`` that mimics a unique violation on ``ix_users_phone``.

    The wrapped driver error is a ``MagicMock`` whose ``diag.constraint_name``
    names the phone index, so code that inspects the violated constraint can
    recognise it as a duplicate-phone conflict.
    """
    driver_error = MagicMock()
    driver_error.diag.constraint_name = "ix_users_phone"
    statement, params = "insert", {}
    return IntegrityError(statement, params, driver_error)
|
|
|
|
|
|
def _db_mock() -> MagicMock:
|
|
db = MagicMock()
|
|
db.commit = AsyncMock()
|
|
db.rollback = AsyncMock()
|
|
db.refresh = AsyncMock()
|
|
db.flush = AsyncMock()
|
|
return db
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_login_with_sms_consumes_code_in_same_transaction_as_tokens(monkeypatch) -> None:
    """Login consumes the SMS code and issues tokens before a single commit.

    Both the atomic consume and token issuance must run before the one and only
    ``commit``; a commit between them would split them across transactions.
    """
    db = _db_mock()
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)
    record = _sms_record()

    async def fake_check(phone: str, code: str, purpose: str):
        # Only the "login" purpose is accepted in this scenario.
        if purpose == "login":
            return record, "验证成功"
        return None, "验证码不存在或已使用"

    # Number of commits already awaited at the moment each step ran.
    commits_seen_at: dict[str, int] = {}
    issue_calls: list[str] = []

    async def fake_consume(*args, **kwargs):
        commits_seen_at["consume"] = db.commit.await_count
        return record

    async def fake_issue_tokens(user_id: str, device_info: str = ""):
        commits_seen_at["issue"] = db.commit.await_count
        issue_calls.append(user_id)
        return {"access_token": "a", "refresh_token": "r"}

    consume = AsyncMock(side_effect=fake_consume)

    monkeypatch.setattr(svc, "_check_sms_code", fake_check)
    monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens)
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    result = await svc.login_with_sms("13800138000", "123456")

    assert result["access_token"] == "a"
    db.commit.assert_awaited_once()
    assert issue_calls == ["user-1"]
    consume.assert_awaited_once()
    # Neither step may happen after (or be separated by) a commit.
    assert commits_seen_at == {"consume": 0, "issue": 0}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_login_with_sms_does_not_issue_tokens_when_consume_fails(monkeypatch) -> None:
    """A failed atomic consume aborts login without issuing tokens or committing.

    Expects an ``AuthError``, exactly one rollback, no commit and no token
    issuance. The consume stub is asserted to have been awaited so the test
    cannot pass merely because an earlier step raised ``AuthError``.
    """
    db = _db_mock()
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)
    issue_tokens = AsyncMock()
    # ``None`` models the code no longer being consumable (e.g. a concurrent use).
    consume = AsyncMock(return_value=None)

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(_sms_record(), "验证成功")),
    )
    monkeypatch.setattr(svc, "_issue_tokens", issue_tokens)
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    with pytest.raises(AuthError):
        await svc.login_with_sms("13800138000", "123456")

    consume.assert_awaited_once()
    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()
    issue_tokens.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_login_with_sms_does_not_issue_tokens_when_token_issue_fails(monkeypatch) -> None:
    """An exception while issuing tokens rolls back the login transaction.

    What is checked: the ``RuntimeError`` raised by ``_issue_tokens`` propagates
    unchanged, the session is rolled back exactly once, and it is never
    committed, so the code consumption performed in the same transaction is
    not committed either.
    """
    db = _db_mock()
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)

    async def failing_issue_tokens(user_id: str, device_info: str = ""):
        raise RuntimeError("token store down")

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(_sms_record(), "验证成功")),
    )
    monkeypatch.setattr(svc, "_issue_tokens", failing_issue_tokens)
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(
        repo,
        "try_consume_verification_code",
        AsyncMock(return_value=_sms_record()),
    )

    with pytest.raises(RuntimeError, match="token store down"):
        await svc.login_with_sms("13800138000", "123456")

    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_register_with_sms_uses_atomic_consume(monkeypatch) -> None:
    """SMS registration consumes the code atomically, then flushes and commits once."""
    session = _db_mock()
    service = AuthService(db=session, sms=MagicMock())
    sms_code = _sms_record()
    atomic_consume = AsyncMock(return_value=sms_code)
    token_pair = {"access_token": "a", "refresh_token": "r"}

    # Service-level hooks: the code checks out and token issuance succeeds.
    monkeypatch.setattr(
        service, "_check_sms_code", AsyncMock(return_value=(sms_code, "验证成功"))
    )
    monkeypatch.setattr(service, "_issue_tokens", AsyncMock(return_value=token_pair))

    # Repository layer: no user yet with this phone or email.
    repo_stubs = (
        ("get_user_by_phone", AsyncMock(return_value=None)),
        ("get_user_by_email", AsyncMock(return_value=None)),
        ("create_user", AsyncMock()),
        ("try_consume_verification_code", atomic_consume),
    )
    for attr_name, stub in repo_stubs:
        monkeypatch.setattr(repo, attr_name, stub)

    await service.register_with_sms("13800138000", "123456", "password1", "nick")

    atomic_consume.assert_awaited_once_with("13800138000", "123456", "register", session)
    session.commit.assert_awaited_once()
    session.flush.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_register_with_sms_maps_phone_integrity_error_to_phone_exists(
    monkeypatch,
) -> None:
    """A phone unique-index violation on flush surfaces as ``PHONE_EXISTS``.

    The session must be rolled back once, never committed, and no tokens issued.
    """
    session = _db_mock()
    # The flush hits the ``ix_users_phone`` constraint.
    session.flush = AsyncMock(side_effect=_phone_integrity_error())
    service = AuthService(db=session, sms=MagicMock())

    sms_code = _sms_record()
    sms_code.purpose = "register"
    issue_tokens = AsyncMock()

    monkeypatch.setattr(
        service, "_check_sms_code", AsyncMock(return_value=(sms_code, "验证成功"))
    )
    monkeypatch.setattr(service, "_issue_tokens", issue_tokens)
    repo_stubs = (
        ("get_user_by_phone", AsyncMock(return_value=None)),
        ("get_user_by_email", AsyncMock(return_value=None)),
        ("create_user", AsyncMock()),
        ("try_consume_verification_code", AsyncMock(return_value=sms_code)),
    )
    for attr_name, stub in repo_stubs:
        monkeypatch.setattr(repo, attr_name, stub)

    with pytest.raises(AuthError) as exc_info:
        await service.register_with_sms("13800138000", "123456", "password1", "nick")

    assert exc_info.value.code == "PHONE_EXISTS"
    session.rollback.assert_awaited_once()
    session.commit.assert_not_awaited()
    issue_tokens.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_login_with_sms_recovers_when_concurrent_registration_wins(
    monkeypatch,
) -> None:
    """Login falls back to the existing user when a phone IntegrityError occurs.

    The first phone lookup misses and the flush raises a phone unique violation.
    The second lookup returns the concurrently created account. Tokens must be
    issued for that account, ``is_new_user`` must be False, and there is one
    commit.
    """
    session = _db_mock()
    session.flush = AsyncMock(side_effect=_phone_integrity_error())
    service = AuthService(db=session, sms=MagicMock())
    sms_code = _sms_record()
    winning_user = MagicMock(id="existing-user")
    issued_for: list[str] = []

    async def record_issue(user_id: str, device_info: str = ""):
        issued_for.append(user_id)
        return {"access_token": "a", "refresh_token": "r"}

    monkeypatch.setattr(
        service, "_check_sms_code", AsyncMock(return_value=(sms_code, "验证成功"))
    )
    monkeypatch.setattr(service, "_issue_tokens", record_issue)
    # Miss on the first lookup, hit on the re-lookup after the conflict.
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(side_effect=[None, winning_user]))
    monkeypatch.setattr(repo, "create_user", AsyncMock())
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=sms_code))

    outcome = await service.login_with_sms("13800138000", "123456")

    assert outcome["access_token"] == "a"
    assert outcome["is_new_user"] is False
    assert issued_for == ["existing-user"]
    session.commit.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_change_phone_maps_phone_integrity_error_to_phone_taken(monkeypatch) -> None:
    """A phone unique violation while changing phone surfaces as ``PHONE_TAKEN``.

    The session must be rolled back exactly once and never committed.
    """
    session = _db_mock()
    session.flush = AsyncMock(side_effect=_phone_integrity_error())
    service = AuthService(db=session, sms=MagicMock())

    sms_code = _sms_record()
    sms_code.purpose = "change_phone"
    current_user = MagicMock(id="user-1", phone="13800138001")

    monkeypatch.setattr(
        service, "_check_sms_code", AsyncMock(return_value=(sms_code, "验证成功"))
    )
    repo_stubs = {
        "get_user_by_phone": AsyncMock(return_value=None),
        "get_user_by_id": AsyncMock(return_value=current_user),
        "try_consume_verification_code": AsyncMock(return_value=sms_code),
    }
    for attr_name, stub in repo_stubs.items():
        monkeypatch.setattr(repo, attr_name, stub)

    with pytest.raises(AuthError) as exc_info:
        await service.change_phone("user-1", "13800138000", "123456")

    assert exc_info.value.code == "PHONE_TAKEN"
    session.rollback.assert_awaited_once()
    session.commit.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_register_with_sms_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """Losing the consume race reports the fresh failure reason.

    The precheck passes, but the atomic consume returns ``None``. The error must
    be ``INVALID_SMS_CODE``, carry the message produced by ``_check_sms_code``
    (not a stale "success" message), and nothing may be committed.
    """
    db = _db_mock()
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)
    consume = AsyncMock(return_value=None)

    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(None, "验证码不存在或已使用")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    with pytest.raises(AuthError) as exc_info:
        await svc.register_with_sms(
            "13800138000",
            "123456",
            "password1",
            "nick",
        )

    assert exc_info.value.code == "INVALID_SMS_CODE"
    # Exact equality also rules out leaking the precheck's "验证成功" message.
    assert exc_info.value.message == "验证码不存在或已使用"
    # Make sure the race path (the consume) was actually exercised.
    consume.assert_awaited_once()
    db.commit.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reset_password_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """Password reset reports the fresh failure reason when the consume race is lost.

    The precheck passes, but the atomic consume returns ``None``. The error must
    be ``INVALID_SMS_CODE`` with the message produced by ``_check_sms_code``,
    and nothing may be committed.
    """
    db = _db_mock()
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)
    user = MagicMock(id="user-1")
    consume = AsyncMock(return_value=None)

    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(None, "验证码已过期")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=user))
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    with pytest.raises(AuthError) as exc_info:
        await svc.reset_password("13800138000", "123456", "newpass1")

    assert exc_info.value.code == "INVALID_SMS_CODE"
    # Exact equality also rules out leaking the precheck's "验证成功" message.
    assert exc_info.value.message == "验证码已过期"
    # Make sure the race path (the consume) was actually exercised.
    consume.assert_awaited_once()
    db.commit.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_change_phone_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """Changing phone reports the fresh failure reason when the consume race is lost.

    The precheck passes, but the atomic consume returns ``None``. The error must
    be ``INVALID_SMS_CODE`` with the message produced by ``_check_sms_code``,
    and nothing may be committed.
    """
    db = _db_mock()
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)
    user = MagicMock(id="user-1", phone="13800138001")
    consume = AsyncMock(return_value=None)

    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(None, "验证码不存在或已使用")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=user))
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    with pytest.raises(AuthError) as exc_info:
        await svc.change_phone("user-1", "13800138000", "123456")

    assert exc_info.value.code == "INVALID_SMS_CODE"
    # Exact equality also rules out leaking the precheck's "验证成功" message.
    assert exc_info.value.message == "验证码不存在或已使用"
    # Make sure the race path (the consume) was actually exercised.
    consume.assert_awaited_once()
    db.commit.assert_not_awaited()
|