96 lines
3.3 KiB
Python
96 lines
3.3 KiB
Python
|
|
"""SMS 发送失败后事务回滚,不应留下验证码记录阻塞立即重试。"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from unittest.mock import AsyncMock, MagicMock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from app.core.errors import ProviderError, RateLimitedError
|
||
|
|
from app.features.auth import repo
|
||
|
|
from app.features.auth.models import SmsVerificationCode
|
||
|
|
from app.features.auth import service as auth_service_mod
|
||
|
|
from app.features.auth.service import AuthService, CODE_EXPIRE_MINUTES
|
||
|
|
|
||
|
|
|
||
|
|
def _make_service(*, sms_send_ok: bool) -> AuthService:
    """Build an AuthService whose DB session and SMS client are mocks.

    Args:
        sms_send_ok: Value the mocked ``sms.send_verification_code`` returns,
            used to simulate a provider success (True) or failure (False).

    Returns:
        An ``AuthService`` wired to a MagicMock DB session (with awaitable
        ``commit``/``rollback``) and a MagicMock SMS client.
    """
    # The service awaits commit/rollback, so those two must be AsyncMocks.
    fake_db = MagicMock(
        commit=AsyncMock(return_value=None),
        rollback=AsyncMock(return_value=None),
    )
    fake_sms = MagicMock()
    # Synchronous mock: the send result is returned directly, not awaited.
    fake_sms.send_verification_code = MagicMock(return_value=sms_send_ok)
    return AuthService(db=fake_db, sms=fake_sms)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_send_sms_after_provider_failure_not_rate_limited(monkeypatch) -> None:
    """A provider failure must not leave a code that throttles an immediate retry.

    The repository functions are replaced by a small in-memory fake:

    * ``create_verification_code`` records every code with a unique id;
    * ``mark_verification_code_expired`` adds the id to an expired set;
    * ``get_recent_code_for_rate_limit`` returns the newest non-expired code
      for the phone.

    The retry therefore really goes through the rate-limit check against
    whatever the failed attempt left behind. A stub that always returns
    ``None`` would let the retry pass even if the failed code were still live.
    """
    phone = "13800138000"

    created: list[SmsVerificationCode] = []
    expired_ids: set[str] = set()
    expire_calls: list[str] = []

    async def fake_get_user_by_phone(p: str, db):
        # No existing user, so registering this phone is allowed.
        return None

    async def fake_get_recent_code_for_rate_limit(p: str, db):
        # Newest code for this phone that has not been marked expired.
        live = [r for r in created if r.phone == p and r.id not in expired_ids]
        return live[-1] if live else None

    async def fake_create_verification_code(record, db):
        # Unique id per record so expiring one code cannot mask another.
        record.id = f"new-record-{len(created) + 1}"
        created.append(record)

    async def fake_mark_expired(code_id, db):
        expire_calls.append(code_id)
        expired_ids.add(code_id)

    monkeypatch.setattr(repo, "get_user_by_phone", fake_get_user_by_phone)
    monkeypatch.setattr(
        repo, "get_recent_code_for_rate_limit", fake_get_recent_code_for_rate_limit
    )
    monkeypatch.setattr(repo, "create_verification_code", fake_create_verification_code)
    monkeypatch.setattr(repo, "mark_verification_code_expired", fake_mark_expired)
    monkeypatch.setattr(auth_service_mod, "_sms_is_configured", lambda: True)

    # First attempt: the SMS provider reports failure.
    svc_fail = _make_service(sms_send_ok=False)
    with pytest.raises(ProviderError) as exc_info:
        await svc_fail.send_sms_code(phone, "register")
    assert "失败" in exc_info.value.message
    assert exc_info.value.error_code == "PROVIDER_ERROR"
    assert exc_info.value.status_code == 502
    # Exactly one code was written, and that same code was then expired.
    assert len(created) == 1
    assert expire_calls == ["new-record-1"]
    svc_fail._sms.send_verification_code.assert_called_once()

    # Immediate retry: must pass the rate-limit check and actually send.
    svc_ok = _make_service(sms_send_ok=True)
    success2, message2, expires_in2 = await svc_ok.send_sms_code(phone, "register")
    assert success2 is True
    assert message2 == "验证码已发送"
    assert expires_in2 == CODE_EXPIRE_MINUTES * 60
    assert len(created) == 2
    svc_ok._sms.send_verification_code.assert_called_once()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_send_sms_rate_limited_raises_rate_limited_error(monkeypatch) -> None:
    """A recent code for the phone makes ``send_sms_code`` raise RateLimitedError.

    Also verifies that a throttled request never reaches the SMS provider.
    """
    phone = "13800138000"
    now = datetime.now(timezone.utc)
    # A code created "just now" for the same phone and purpose. The patched
    # rate-limit lookup below returns it.
    recent = SmsVerificationCode(
        id="recent-1",
        phone=phone,
        code="111111",
        purpose="register",
        expires_at=now,
        created_at=now,
    )

    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(
        repo, "get_recent_code_for_rate_limit", AsyncMock(return_value=recent)
    )
    monkeypatch.setattr(auth_service_mod, "_sms_is_configured", lambda: True)

    svc = _make_service(sms_send_ok=True)
    with pytest.raises(RateLimitedError) as exc_info:
        await svc.send_sms_code(phone, "register")

    assert "频繁" in exc_info.value.message
    assert exc_info.value.error_code == "RATE_LIMITED"
    assert exc_info.value.status_code == 429
    # Throttling must short-circuit before any provider call is made.
    svc._sms.send_verification_code.assert_not_called()
|