"""Refresh token rotation and reuse detection.""" from __future__ import annotations from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock import pytest from app.features.auth import repo from app.features.auth.models import RefreshToken from app.features.auth.service import AuthError, AuthService def _refresh_record( *, token: str = "old-refresh", user_id: str = "user-1", is_revoked: bool = False, expired: bool = False, replaced_by_token_id: str | None = None, rotated_at: datetime | None = None, ) -> MagicMock: now = datetime.now(timezone.utc) record = MagicMock(spec=RefreshToken) record.id = "rt-1" record.user_id = user_id record.token = token record.expires_at = ( now - timedelta(days=1) if expired else now + timedelta(days=30) ) record.created_at = now record.is_revoked = is_revoked record.device_info = "iphone" record.replaced_by_token_id = replaced_by_token_id record.rotated_at = rotated_at return record def _db_mock() -> MagicMock: db = MagicMock() db.commit = AsyncMock() db.rollback = AsyncMock() db.refresh = AsyncMock() db.flush = AsyncMock() return db @pytest.mark.asyncio async def test_refresh_rotates_token_in_transaction(monkeypatch) -> None: db = _db_mock() sms = MagicMock() svc = AuthService(db=db, sms=sms) consumed = _refresh_record() monkeypatch.setattr( repo, "try_consume_refresh_token", AsyncMock(return_value=consumed), ) monkeypatch.setattr( repo, "get_user_by_id", AsyncMock(return_value=MagicMock(id="user-1")), ) monkeypatch.setattr( svc, "_issue_tokens", AsyncMock( return_value={ "access_token": "new-access", "refresh_token": "new-refresh", "refresh_token_id": "rt-2", } ), ) monkeypatch.setattr(repo, "link_refresh_rotation", AsyncMock()) result = await svc.refresh_tokens("old-refresh") assert result["access_token"] == "new-access" assert result["refresh_token"] == "new-refresh" assert result["refresh_token"] != "old-refresh" db.commit.assert_awaited_once() svc._issue_tokens.assert_awaited_once_with("user-1", "iphone") repo.try_consume_refresh_token.assert_awaited_once_with("old-refresh", db) repo.link_refresh_rotation.assert_awaited_once() assert repo.link_refresh_rotation.await_args.args[0] == "rt-1" assert repo.link_refresh_rotation.await_args.args[1] == "rt-2" @pytest.mark.asyncio async def test_refresh_grace_reuse_returns_idempotent_tokens(monkeypatch) -> None: db = _db_mock() sms = MagicMock() svc = AuthService(db=db, sms=sms) now = datetime.now(timezone.utc) revoked = _refresh_record( is_revoked=True, replaced_by_token_id="rt-2", rotated_at=now, ) replacement = _refresh_record(token="new-refresh", is_revoked=False) replacement.id = "rt-2" monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) monkeypatch.setattr( repo, "get_refresh_token_by_token", AsyncMock(return_value=revoked), ) monkeypatch.setattr( repo, "get_refresh_token_by_id", AsyncMock(return_value=replacement), ) monkeypatch.setattr( repo, "get_user_by_id", AsyncMock(return_value=MagicMock(id="user-1")), ) monkeypatch.setattr( svc, "_revoke_all_active_tokens_in_session", AsyncMock(return_value=2), ) monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) result = await svc.refresh_tokens("old-refresh") assert result["refresh_token"] == "new-refresh" assert result["access_token"] svc._revoke_all_active_tokens_in_session.assert_not_awaited() svc._issue_tokens.assert_not_awaited() db.commit.assert_awaited_once() @pytest.mark.asyncio async def test_refresh_reuse_revokes_all_sessions(monkeypatch) -> None: db = _db_mock() sms = MagicMock() svc = AuthService(db=db, sms=sms) now = datetime.now(timezone.utc) revoked = _refresh_record( is_revoked=True, replaced_by_token_id="rt-2", rotated_at=now - timedelta(seconds=120), ) monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) monkeypatch.setattr( repo, "get_refresh_token_by_token", AsyncMock(return_value=revoked), ) monkeypatch.setattr( svc, "_revoke_all_active_tokens_in_session", AsyncMock(return_value=2), ) monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) with pytest.raises(AuthError) as exc_info: await svc.refresh_tokens("old-refresh") assert exc_info.value.code == "REFRESH_TOKEN_REUSE" assert exc_info.value.error_code == "REFRESH_TOKEN_REUSE" svc._revoke_all_active_tokens_in_session.assert_awaited_once_with("user-1") db.commit.assert_awaited_once() svc._issue_tokens.assert_not_awaited() @pytest.mark.asyncio async def test_refresh_unknown_token(monkeypatch) -> None: db = _db_mock() sms = MagicMock() svc = AuthService(db=db, sms=sms) monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) monkeypatch.setattr(repo, "get_refresh_token_by_token", AsyncMock(return_value=None)) monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) with pytest.raises(AuthError) as exc_info: await svc.refresh_tokens("missing") assert exc_info.value.code == "INVALID_TOKEN" db.rollback.assert_awaited_once() svc._issue_tokens.assert_not_awaited() @pytest.mark.asyncio async def test_refresh_expired_token(monkeypatch) -> None: db = _db_mock() sms = MagicMock() svc = AuthService(db=db, sms=sms) expired = _refresh_record(expired=True) monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) monkeypatch.setattr( repo, "get_refresh_token_by_token", AsyncMock(return_value=expired), ) monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) with pytest.raises(AuthError) as exc_info: await svc.refresh_tokens("old-refresh") assert exc_info.value.code == "TOKEN_EXPIRED" db.rollback.assert_awaited_once() svc._issue_tokens.assert_not_awaited() @pytest.mark.asyncio async def test_refresh_user_deleted_rolls_back_consume(monkeypatch) -> None: db = _db_mock() sms = MagicMock() svc = AuthService(db=db, sms=sms) consumed = _refresh_record() monkeypatch.setattr( repo, "try_consume_refresh_token", AsyncMock(return_value=consumed), ) monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=None)) monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) with pytest.raises(AuthError) as exc_info: await svc.refresh_tokens("old-refresh") assert exc_info.value.code == "USER_NOT_FOUND" db.rollback.assert_awaited_once() svc._issue_tokens.assert_not_awaited()