Files
life-echo/api/tests/test_auth_sms_verify_transactional.py

372 lines
12 KiB
Python
Raw Normal View History

"""SMS 验证码原子消耗与业务事务。"""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy.exc import IntegrityError
from app.features.auth import repo
from app.features.auth.models import SmsVerificationCode
from app.features.auth.service import AuthError, AuthService
def _sms_record(
    *,
    phone: str = "13800138000",
    code: str = "123456",
    purpose: str = "login",
) -> SmsVerificationCode:
    """Build an unused, unexpired SMS verification-code row for tests.

    Args:
        phone: Phone number the code was sent to.
        code: The verification code value.
        purpose: Flow the code was issued for (e.g. ``"login"``,
            ``"register"``, ``"change_phone"``). Defaults to ``"login"`` so
            existing callers keep their previous behavior.

    Returns:
        A newly constructed ``SmsVerificationCode`` with fixed id ``"code-1"``,
        ``is_used``/``is_expired`` both False, created now (UTC) and expiring
        five minutes later.
    """
    now = datetime.now(timezone.utc)
    return SmsVerificationCode(
        id="code-1",
        phone=phone,
        code=code,
        purpose=purpose,
        is_used=False,
        is_expired=False,
        expires_at=now + timedelta(minutes=5),
        created_at=now,
    )
def _phone_integrity_error() -> IntegrityError:
    """Fake the IntegrityError for a violation of the unique phone index.

    The driver-level ``orig`` is a MagicMock whose ``diag.constraint_name`` is
    ``"ix_users_phone"``; presumably this is what the service inspects to tell
    a duplicate phone apart from other integrity failures.
    """
    driver_error = MagicMock(**{"diag.constraint_name": "ix_users_phone"})
    return IntegrityError("insert", {}, driver_error)
def _db_mock() -> MagicMock:
db = MagicMock()
db.commit = AsyncMock()
db.rollback = AsyncMock()
db.refresh = AsyncMock()
db.flush = AsyncMock()
return db
@pytest.mark.asyncio
async def test_login_with_sms_consumes_code_in_same_transaction_as_tokens(monkeypatch) -> None:
    """Login consumes the code and issues tokens before a single commit.

    The consume, token-issue and commit side effects are recorded in order.
    The test therefore checks what its name claims: both the code consumption
    and the token issuance happen before the one commit, i.e. in the same
    transaction, and nothing is rolled back.
    """
    events: list[str] = []
    db = _db_mock()

    async def record_commit(*_args, **_kwargs) -> None:
        events.append("commit")

    # Replace commit before building the service so it sees the recording mock.
    db.commit = AsyncMock(side_effect=record_commit)
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)
    record = _sms_record()

    async def fake_check(phone: str, code: str, purpose: str):
        if purpose == "login":
            return record, "验证成功"
        return None, "验证码不存在或已使用"

    issue_calls: list[str] = []

    async def fake_issue_tokens(user_id: str, device_info: str = ""):
        events.append("issue")
        issue_calls.append(user_id)
        return {"access_token": "a", "refresh_token": "r"}

    async def fake_consume(*_args, **_kwargs):
        events.append("consume")
        return record

    consume = AsyncMock(side_effect=fake_consume)
    monkeypatch.setattr(svc, "_check_sms_code", fake_check)
    monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens)
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    result = await svc.login_with_sms("13800138000", "123456")

    assert result["access_token"] == "a"
    assert issue_calls == ["user-1"]
    # Same call shape the register flow is pinned to (phone, code, purpose, db).
    consume.assert_awaited_once_with("13800138000", "123456", "login", db)
    db.commit.assert_awaited_once()
    db.rollback.assert_not_awaited()
    # Consumption and token issuance must both precede the single commit.
    assert events == ["consume", "issue", "commit"]
@pytest.mark.asyncio
async def test_login_with_sms_does_not_issue_tokens_when_consume_fails(monkeypatch) -> None:
    """If the atomic consume returns None, login raises and rolls back.

    An ``AuthError`` must be raised, the session rolled back once, nothing
    committed, and no tokens issued.
    """
    session = MagicMock()
    session.commit, session.rollback = AsyncMock(), AsyncMock()
    svc = AuthService(db=session, sms=MagicMock())
    issue_tokens = AsyncMock()
    service_stubs = {
        "_check_sms_code": AsyncMock(return_value=(_sms_record(), "验证成功")),
        "_issue_tokens": issue_tokens,
    }
    for attr, stub in service_stubs.items():
        monkeypatch.setattr(svc, attr, stub)
    monkeypatch.setattr(
        repo, "get_user_by_phone", AsyncMock(return_value=MagicMock(id="user-1"))
    )
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None))

    with pytest.raises(AuthError):
        await svc.login_with_sms("13800138000", "123456")

    session.rollback.assert_awaited_once()
    session.commit.assert_not_awaited()
    issue_tokens.assert_not_awaited()
@pytest.mark.asyncio
async def test_login_with_sms_does_not_issue_tokens_when_token_issue_fails(monkeypatch) -> None:
    """A failure while issuing tokens propagates and rolls the login back.

    The ``RuntimeError`` from token issuance must surface unchanged. The
    session must be rolled back once and never committed, so nothing from the
    attempt (including the code consumption) is committed.
    """
    session = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    svc = AuthService(db=session, sms=MagicMock())

    async def failing_issue_tokens(user_id: str, device_info: str = ""):
        raise RuntimeError("token store down")

    checked_record = _sms_record()
    monkeypatch.setattr(
        svc, "_check_sms_code", AsyncMock(return_value=(checked_record, "验证成功"))
    )
    monkeypatch.setattr(svc, "_issue_tokens", failing_issue_tokens)
    known_user = MagicMock(id="user-1")
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=known_user))
    consumed_record = _sms_record()
    monkeypatch.setattr(
        repo, "try_consume_verification_code", AsyncMock(return_value=consumed_record)
    )

    with pytest.raises(RuntimeError, match="token store down"):
        await svc.login_with_sms("13800138000", "123456")

    session.rollback.assert_awaited_once()
    session.commit.assert_not_awaited()
@pytest.mark.asyncio
async def test_register_with_sms_uses_atomic_consume(monkeypatch) -> None:
    """Registration consumes its code atomically and commits once on success.

    The fixture code is given purpose ``"register"`` to match the flow under
    test; the default fixture is a ``"login"`` code, as the sibling
    integrity-error test also corrects. On the happy path the test checks
    four things:

    - the consume call shape;
    - a single session flush;
    - token issuance;
    - one commit with no rollback.
    """
    db = _db_mock()
    sms = MagicMock()
    svc = AuthService(db=db, sms=sms)
    record = _sms_record()
    record.purpose = "register"
    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(record, "验证成功")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "create_user", AsyncMock())
    issue_tokens = AsyncMock(return_value={"access_token": "a", "refresh_token": "r"})
    monkeypatch.setattr(svc, "_issue_tokens", issue_tokens)
    consume = AsyncMock(return_value=record)
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    await svc.register_with_sms(
        "13800138000",
        "123456",
        "password1",
        "nick",
    )

    consume.assert_awaited_once_with("13800138000", "123456", "register", db)
    db.flush.assert_awaited_once()
    issue_tokens.assert_awaited_once()
    db.commit.assert_awaited_once()
    db.rollback.assert_not_awaited()
@pytest.mark.asyncio
async def test_register_with_sms_maps_phone_integrity_error_to_phone_exists(
    monkeypatch,
) -> None:
    """A unique-phone violation at flush time surfaces as ``PHONE_EXISTS``.

    The phone/email lookups see no existing user, but the flush raises an
    ``IntegrityError`` on ``ix_users_phone``. This models losing a
    registration race. The service must roll back, skip the commit and not
    issue tokens.
    """
    db = _db_mock()
    db.flush = AsyncMock(side_effect=_phone_integrity_error())
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record()
    record.purpose = "register"
    issue_tokens = AsyncMock()

    stubs = [
        (svc, "_check_sms_code", AsyncMock(return_value=(record, "验证成功"))),
        (repo, "get_user_by_phone", AsyncMock(return_value=None)),
        (repo, "get_user_by_email", AsyncMock(return_value=None)),
        (repo, "create_user", AsyncMock()),
        (repo, "try_consume_verification_code", AsyncMock(return_value=record)),
        (svc, "_issue_tokens", issue_tokens),
    ]
    for target, attr, stub in stubs:
        monkeypatch.setattr(target, attr, stub)

    with pytest.raises(AuthError) as exc_info:
        await svc.register_with_sms("13800138000", "123456", "password1", "nick")

    assert exc_info.value.code == "PHONE_EXISTS"
    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()
    issue_tokens.assert_not_awaited()
@pytest.mark.asyncio
async def test_login_with_sms_recovers_when_concurrent_registration_wins(
    monkeypatch,
) -> None:
    """SMS login falls back to the existing user when auto-registration races.

    The first phone lookup returns None and the flush raises an
    ``IntegrityError`` on ``ix_users_phone``, as if a concurrent request
    registered the same phone first. The second lookup returns that user.
    Login must then issue tokens for it, report ``is_new_user`` as False and
    commit exactly once.
    """
    db = _db_mock()
    db.flush = AsyncMock(side_effect=_phone_integrity_error())
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record()
    race_winner = MagicMock(id="existing-user")
    issued_for: list[str] = []

    async def fake_issue_tokens(user_id: str, device_info: str = ""):
        issued_for.append(user_id)
        return {"access_token": "a", "refresh_token": "r"}

    monkeypatch.setattr(svc, "_check_sms_code", AsyncMock(return_value=(record, "验证成功")))
    # First lookup: no user yet; second lookup (after the conflict): the winner.
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(side_effect=[None, race_winner]))
    monkeypatch.setattr(repo, "create_user", AsyncMock())
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=record))
    monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens)

    outcome = await svc.login_with_sms("13800138000", "123456")

    assert outcome["access_token"] == "a"
    assert outcome["is_new_user"] is False
    assert issued_for == ["existing-user"]
    db.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_change_phone_maps_phone_integrity_error_to_phone_taken(monkeypatch) -> None:
    """A unique-phone violation while changing phone surfaces as ``PHONE_TAKEN``.

    The lookup for the new number finds no owner, but the flush raises an
    ``IntegrityError`` on ``ix_users_phone``. The service must roll back and
    not commit.
    """
    db = _db_mock()
    db.flush = AsyncMock(side_effect=_phone_integrity_error())
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record()
    record.purpose = "change_phone"
    current_user = MagicMock(id="user-1", phone="13800138001")

    monkeypatch.setattr(svc, "_check_sms_code", AsyncMock(return_value=(record, "验证成功")))
    for repo_fn, stub in (
        ("get_user_by_phone", AsyncMock(return_value=None)),
        ("get_user_by_id", AsyncMock(return_value=current_user)),
        ("try_consume_verification_code", AsyncMock(return_value=record)),
    ):
        monkeypatch.setattr(repo, repo_fn, stub)

    with pytest.raises(AuthError) as exc_info:
        await svc.change_phone("user-1", "13800138000", "123456")

    assert exc_info.value.code == "PHONE_TAKEN"
    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()
@pytest.mark.asyncio
async def test_register_with_sms_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """When the atomic consume fails, register reports the current reason.

    ``try_consume_verification_code`` returns None and the stubbed
    ``_check_sms_code`` reports why the code is unusable. The raised
    ``INVALID_SMS_CODE`` error must carry that message, never the success
    text.
    """
    expected_message = "验证码不存在或已使用"
    svc = AuthService(db=_db_mock(), sms=MagicMock())
    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(svc, "_check_sms_code", AsyncMock(return_value=(None, expected_message)))
    for repo_fn in ("get_user_by_phone", "get_user_by_email", "try_consume_verification_code"):
        monkeypatch.setattr(repo, repo_fn, AsyncMock(return_value=None))

    with pytest.raises(AuthError) as exc_info:
        await svc.register_with_sms("13800138000", "123456", "password1", "nick")

    error = exc_info.value
    assert error.code == "INVALID_SMS_CODE"
    assert error.message != "验证成功"
    assert error.message == expected_message
@pytest.mark.asyncio
async def test_reset_password_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """When the atomic consume fails, reset_password reports the current reason.

    The user exists, but ``try_consume_verification_code`` returns None and
    the stubbed ``_check_sms_code`` says the code has expired. The raised
    ``INVALID_SMS_CODE`` error must carry that message, never the success
    text.
    """
    expected_message = "验证码已过期"
    svc = AuthService(db=_db_mock(), sms=MagicMock())
    account = MagicMock(id="user-1")
    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(svc, "_check_sms_code", AsyncMock(return_value=(None, expected_message)))
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=account))
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None))

    with pytest.raises(AuthError) as exc_info:
        await svc.reset_password("13800138000", "123456", "newpass1")

    error = exc_info.value
    assert error.code == "INVALID_SMS_CODE"
    assert error.message != "验证成功"
    assert error.message == expected_message
@pytest.mark.asyncio
async def test_change_phone_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """When the atomic consume fails, change_phone reports the current reason.

    The new number is free and the user exists, but
    ``try_consume_verification_code`` returns None. The stubbed
    ``_check_sms_code`` reports the code as missing or used. The raised
    ``INVALID_SMS_CODE`` error must carry that message, never the success
    text.
    """
    expected_message = "验证码不存在或已使用"
    svc = AuthService(db=_db_mock(), sms=MagicMock())
    current_user = MagicMock(id="user-1", phone="13800138001")
    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(svc, "_check_sms_code", AsyncMock(return_value=(None, expected_message)))
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=current_user))
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None))

    with pytest.raises(AuthError) as exc_info:
        await svc.change_phone("user-1", "13800138000", "123456")

    error = exc_info.value
    assert error.code == "INVALID_SMS_CODE"
    assert error.message != "验证成功"
    assert error.message == expected_message