"""Tests for atomic SMS verification-code consumption within the business transaction.

These tests pin how ``AuthService`` flows (SMS login, SMS registration, password
reset, phone change) use ``repo.try_consume_verification_code`` together with
their database work:

* commit exactly once on success;
* roll back and never commit when consuming the code or issuing tokens fails;
* map a phone-uniqueness ``IntegrityError`` to the flow's domain error code;
* report a freshly re-checked validation message (never the stale "success"
  message) when a concurrent request consumed the code first.
"""

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock

import pytest
from sqlalchemy.exc import IntegrityError

from app.features.auth import repo
from app.features.auth.models import SmsVerificationCode
from app.features.auth.service import AuthError, AuthService


def _sms_record(
    *,
    phone: str = "13800138000",
    code: str = "123456",
    purpose: str = "login",
) -> SmsVerificationCode:
    """Build an unused, unexpired verification code that expires in five minutes.

    Args:
        phone: Phone number the code was sent to.
        code: The verification code digits.
        purpose: Flow the code was issued for; defaults to ``"login"``. Pass the
            flow's purpose explicitly so the fixture matches what the service
            would look up.
    """
    now = datetime.now(timezone.utc)
    return SmsVerificationCode(
        id="code-1",
        phone=phone,
        code=code,
        purpose=purpose,
        is_used=False,
        is_expired=False,
        expires_at=now + timedelta(minutes=5),
        created_at=now,
    )


def _phone_integrity_error() -> IntegrityError:
    """Return an IntegrityError whose driver error names the ``ix_users_phone`` constraint."""
    orig = MagicMock()
    orig.diag.constraint_name = "ix_users_phone"
    return IntegrityError("insert", {}, orig)


def _db_mock() -> MagicMock:
    """Return a session double whose commit/rollback/refresh/flush are awaitable.

    Every test uses this helper so that a service awaiting any of these methods
    never trips over a plain (non-awaitable) ``MagicMock`` attribute.
    """
    db = MagicMock()
    db.commit = AsyncMock()
    db.rollback = AsyncMock()
    db.refresh = AsyncMock()
    db.flush = AsyncMock()
    return db


@pytest.mark.asyncio
async def test_login_with_sms_consumes_code_in_same_transaction_as_tokens(monkeypatch) -> None:
    """Login consumes the code on the service's session, issues tokens, then commits once."""
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record()
    events: list[str] = []
    issue_calls: list[str] = []

    async def fake_check(phone: str, code: str, purpose: str):
        if purpose == "login":
            return record, "验证成功"
        return None, "验证码不存在或已使用"

    async def fake_issue_tokens(user_id: str, device_info: str = ""):
        events.append("issue")
        issue_calls.append(user_id)
        return {"access_token": "a", "refresh_token": "r"}

    def record_consume(*_args, **_kwargs):
        events.append("consume")
        return record

    consume = AsyncMock(side_effect=record_consume)
    db.commit.side_effect = lambda *_args, **_kwargs: events.append("commit")

    monkeypatch.setattr(svc, "_check_sms_code", fake_check)
    monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens)
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    result = await svc.login_with_sms("13800138000", "123456")

    assert result["access_token"] == "a"
    assert issue_calls == ["user-1"]
    db.commit.assert_awaited_once()
    consume.assert_awaited_once()
    # "Same transaction": the consume must run on the very session that is committed.
    consume_call = consume.await_args
    assert db in consume_call.args or db in consume_call.kwargs.values()
    # Tokens are issued only after a successful consume; the commit comes last.
    assert events == ["consume", "issue", "commit"]


@pytest.mark.asyncio
async def test_login_with_sms_does_not_issue_tokens_when_consume_fails(monkeypatch) -> None:
    """A lost consume race raises AuthError, rolls back, and issues no tokens."""
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(_sms_record(), "验证成功")),
    )
    monkeypatch.setattr(svc, "_issue_tokens", AsyncMock())
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None))

    with pytest.raises(AuthError):
        await svc.login_with_sms("13800138000", "123456")

    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()
    svc._issue_tokens.assert_not_awaited()


@pytest.mark.asyncio
async def test_login_with_sms_does_not_issue_tokens_when_token_issue_fails(monkeypatch) -> None:
    """If token issuance raises, the error propagates and the consume is rolled back, not committed."""
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(_sms_record(), "验证成功")),
    )

    async def failing_issue_tokens(user_id: str, device_info: str = ""):
        raise RuntimeError("token store down")

    monkeypatch.setattr(svc, "_issue_tokens", failing_issue_tokens)
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(
        repo,
        "try_consume_verification_code",
        AsyncMock(return_value=_sms_record()),
    )

    with pytest.raises(RuntimeError, match="token store down"):
        await svc.login_with_sms("13800138000", "123456")

    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()


@pytest.mark.asyncio
async def test_register_with_sms_uses_atomic_consume(monkeypatch) -> None:
    """Registration consumes the "register" code on the service session, flushes, and commits once."""
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record(purpose="register")

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(record, "验证成功")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "create_user", AsyncMock())
    monkeypatch.setattr(
        svc,
        "_issue_tokens",
        AsyncMock(return_value={"access_token": "a", "refresh_token": "r"}),
    )
    consume = AsyncMock(return_value=record)
    monkeypatch.setattr(repo, "try_consume_verification_code", consume)

    await svc.register_with_sms(
        "13800138000",
        "123456",
        "password1",
        "nick",
    )

    consume.assert_awaited_once_with("13800138000", "123456", "register", db)
    db.commit.assert_awaited_once()
    db.flush.assert_awaited_once()


@pytest.mark.asyncio
async def test_register_with_sms_maps_phone_integrity_error_to_phone_exists(
    monkeypatch,
) -> None:
    """A phone-uniqueness violation at flush becomes PHONE_EXISTS, with rollback and no tokens."""
    db = _db_mock()
    db.flush = AsyncMock(side_effect=_phone_integrity_error())
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record(purpose="register")

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(record, "验证成功")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "create_user", AsyncMock())
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=record))
    monkeypatch.setattr(svc, "_issue_tokens", AsyncMock())

    with pytest.raises(AuthError) as exc_info:
        await svc.register_with_sms(
            "13800138000",
            "123456",
            "password1",
            "nick",
        )

    assert exc_info.value.code == "PHONE_EXISTS"
    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()
    svc._issue_tokens.assert_not_awaited()


@pytest.mark.asyncio
async def test_login_with_sms_recovers_when_concurrent_registration_wins(
    monkeypatch,
) -> None:
    """If auto-registration loses a phone-uniqueness race, login falls back to the existing user."""
    db = _db_mock()
    db.flush = AsyncMock(side_effect=_phone_integrity_error())
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record()
    existing_user = MagicMock(id="existing-user")

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(record, "验证成功")),
    )
    # First lookup: no user yet (triggers auto-registration); second: the concurrent winner.
    monkeypatch.setattr(
        repo,
        "get_user_by_phone",
        AsyncMock(side_effect=[None, existing_user]),
    )
    monkeypatch.setattr(repo, "create_user", AsyncMock())
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=record))

    issue_calls: list[str] = []

    async def fake_issue_tokens(user_id: str, device_info: str = ""):
        issue_calls.append(user_id)
        return {"access_token": "a", "refresh_token": "r"}

    monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens)

    result = await svc.login_with_sms("13800138000", "123456")

    assert result["access_token"] == "a"
    assert result["is_new_user"] is False
    assert issue_calls == ["existing-user"]
    db.commit.assert_awaited_once()


@pytest.mark.asyncio
async def test_change_phone_maps_phone_integrity_error_to_phone_taken(monkeypatch) -> None:
    """A phone-uniqueness violation while changing phone becomes PHONE_TAKEN with rollback."""
    db = _db_mock()
    db.flush = AsyncMock(side_effect=_phone_integrity_error())
    svc = AuthService(db=db, sms=MagicMock())
    record = _sms_record(purpose="change_phone")
    user = MagicMock(id="user-1", phone="13800138001")

    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(record, "验证成功")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=user))
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=record))

    with pytest.raises(AuthError) as exc_info:
        await svc.change_phone("user-1", "13800138000", "123456")

    assert exc_info.value.code == "PHONE_TAKEN"
    db.rollback.assert_awaited_once()
    db.commit.assert_not_awaited()


@pytest.mark.asyncio
async def test_register_with_sms_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """When the consume loses a race, registration reports the re-checked message, not success."""
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())

    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(None, "验证码不存在或已使用")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None))

    with pytest.raises(AuthError) as exc_info:
        await svc.register_with_sms(
            "13800138000",
            "123456",
            "password1",
            "nick",
        )

    assert exc_info.value.code == "INVALID_SMS_CODE"
    # Guard against regressing to the stale pre-check "success" message.
    assert exc_info.value.message != "验证成功"
    assert exc_info.value.message == "验证码不存在或已使用"


@pytest.mark.asyncio
async def test_reset_password_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """When the consume loses a race, password reset reports the re-checked message, not success."""
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())
    user = MagicMock(id="user-1")

    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(None, "验证码已过期")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=user))
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None))

    with pytest.raises(AuthError) as exc_info:
        await svc.reset_password("13800138000", "123456", "newpass1")

    assert exc_info.value.code == "INVALID_SMS_CODE"
    # Guard against regressing to the stale pre-check "success" message.
    assert exc_info.value.message != "验证成功"
    assert exc_info.value.message == "验证码已过期"


@pytest.mark.asyncio
async def test_change_phone_consume_race_returns_fresh_invalid_message(
    monkeypatch,
) -> None:
    """When the consume loses a race, phone change reports the re-checked message, not success."""
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())
    user = MagicMock(id="user-1", phone="13800138001")

    monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None))
    monkeypatch.setattr(
        svc,
        "_check_sms_code",
        AsyncMock(return_value=(None, "验证码不存在或已使用")),
    )
    monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=user))
    monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None))

    with pytest.raises(AuthError) as exc_info:
        await svc.change_phone("user-1", "13800138000", "123456")

    assert exc_info.value.code == "INVALID_SMS_CODE"
    # Guard against regressing to the stale pre-check "success" message.
    assert exc_info.value.message != "验证成功"
    assert exc_info.value.message == "验证码不存在或已使用"