Files
life-echo/api/tests/test_auth_refresh_rotation.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

235 lines
7.0 KiB
Python

"""Refresh token rotation and reuse detection."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.features.auth import repo
from app.features.auth.models import RefreshToken
from app.features.auth.service import AuthError, AuthService
def _refresh_record(
    *,
    token: str = "old-refresh",
    user_id: str = "user-1",
    is_revoked: bool = False,
    expired: bool = False,
    replaced_by_token_id: str | None = None,
    rotated_at: datetime | None = None,
    token_id: str = "rt-1",
) -> MagicMock:
    """Build a ``MagicMock`` shaped like a ``RefreshToken`` row.

    All arguments are keyword-only.

    Args:
        token: Value stored in ``record.token``.
        user_id: Value stored in ``record.user_id``.
        is_revoked: Value stored in ``record.is_revoked``.
        expired: When true, ``expires_at`` is one day in the past;
            otherwise it is 30 days in the future.
        replaced_by_token_id: Value stored in ``record.replaced_by_token_id``.
        rotated_at: Value stored in ``record.rotated_at``.
        token_id: Value stored in ``record.id``. Defaults to ``"rt-1"`` so
            existing callers see the same record as before; pass another id
            when a test needs a second, distinct record.

    Returns:
        The populated mock. ``created_at`` is the current UTC time and
        ``device_info`` is always ``"iphone"``.
    """
    now = datetime.now(timezone.utc)
    if expired:
        expires_at = now - timedelta(days=1)
    else:
        expires_at = now + timedelta(days=30)

    # spec=RefreshToken makes reads of attributes the model lacks fail loudly.
    record = MagicMock(spec=RefreshToken)
    record.id = token_id
    record.user_id = user_id
    record.token = token
    record.expires_at = expires_at
    record.created_at = now
    record.is_revoked = is_revoked
    record.device_info = "iphone"
    record.replaced_by_token_id = replaced_by_token_id
    record.rotated_at = rotated_at
    return record
def _db_mock() -> MagicMock:
    """Return the ``db`` double that the tests hand to ``AuthService``.

    The object is a plain ``MagicMock`` whose ``commit``, ``rollback``,
    ``refresh`` and ``flush`` attributes are ``AsyncMock`` instances, so they
    can be awaited and checked with the ``assert_awaited*`` helpers.
    """
    session = MagicMock()
    for method_name in ("commit", "rollback", "refresh", "flush"):
        setattr(session, method_name, AsyncMock())
    return session
@pytest.mark.asyncio
async def test_refresh_rotates_token_in_transaction(monkeypatch) -> None:
    """A consumable token yields a new pair, is linked to it, and commits once.

    The repo consume call returns an active record, the user lookup succeeds
    and ``_issue_tokens`` is stubbed to return a fixed payload. The test then
    checks the returned tokens, the single commit, and the arguments passed to
    the consume, issue and link calls.
    """
    session = _db_mock()
    sms_client = MagicMock()
    service = AuthService(db=session, sms=sms_client)

    consume_mock = AsyncMock(return_value=_refresh_record())
    user_lookup_mock = AsyncMock(return_value=MagicMock(id="user-1"))
    issue_mock = AsyncMock(
        return_value={
            "access_token": "new-access",
            "refresh_token": "new-refresh",
            "refresh_token_id": "rt-2",
        }
    )
    link_mock = AsyncMock()

    monkeypatch.setattr(repo, "try_consume_refresh_token", consume_mock)
    monkeypatch.setattr(repo, "get_user_by_id", user_lookup_mock)
    monkeypatch.setattr(service, "_issue_tokens", issue_mock)
    monkeypatch.setattr(repo, "link_refresh_rotation", link_mock)

    tokens = await service.refresh_tokens("old-refresh")

    assert tokens["access_token"] == "new-access"
    assert tokens["refresh_token"] == "new-refresh"
    assert tokens["refresh_token"] != "old-refresh"

    session.commit.assert_awaited_once()
    # "iphone" is the device_info carried by the consumed record.
    issue_mock.assert_awaited_once_with("user-1", "iphone")
    consume_mock.assert_awaited_once_with("old-refresh", session)

    # The first two positional arguments of the link call are the old and new ids.
    link_mock.assert_awaited_once()
    link_args = link_mock.await_args.args
    assert link_args[0] == "rt-1"
    assert link_args[1] == "rt-2"
@pytest.mark.asyncio
async def test_refresh_grace_reuse_returns_idempotent_tokens(monkeypatch) -> None:
    """Re-presenting a token rotated just now returns the replacement token.

    The consume call reports nothing to consume, the token lookup returns a
    revoked record rotated "now" that points at ``rt-2``, and the id lookup
    returns that replacement. The test expects the replacement's refresh token
    and a truthy access token, no mass revocation, no call to the stubbed
    ``_issue_tokens``, and exactly one commit.
    """
    session = _db_mock()
    sms_client = MagicMock()
    service = AuthService(db=session, sms=sms_client)

    rotated_moment = datetime.now(timezone.utc)
    already_rotated = _refresh_record(
        is_revoked=True,
        replaced_by_token_id="rt-2",
        rotated_at=rotated_moment,
    )
    successor = _refresh_record(token="new-refresh", is_revoked=False)
    successor.id = "rt-2"

    repo_patches = {
        "try_consume_refresh_token": AsyncMock(return_value=None),
        "get_refresh_token_by_token": AsyncMock(return_value=already_rotated),
        "get_refresh_token_by_id": AsyncMock(return_value=successor),
        "get_user_by_id": AsyncMock(return_value=MagicMock(id="user-1")),
    }
    for attr_name, stub in repo_patches.items():
        monkeypatch.setattr(repo, attr_name, stub)

    revoke_all_mock = AsyncMock(return_value=2)
    issue_mock = AsyncMock()
    monkeypatch.setattr(service, "_revoke_all_active_tokens_in_session", revoke_all_mock)
    monkeypatch.setattr(service, "_issue_tokens", issue_mock)

    tokens = await service.refresh_tokens("old-refresh")

    assert tokens["refresh_token"] == "new-refresh"
    assert tokens["access_token"]
    revoke_all_mock.assert_not_awaited()
    issue_mock.assert_not_awaited()
    session.commit.assert_awaited_once()
@pytest.mark.asyncio
async def test_refresh_reuse_revokes_all_sessions(monkeypatch) -> None:
    """Replaying a token rotated 120 seconds ago is reported as reuse.

    The looked-up record is revoked, already replaced by ``rt-2`` and rotated
    two minutes in the past (presumably beyond the service's grace window,
    whose length is not visible from this test). The call must raise
    ``AuthError`` with ``REFRESH_TOKEN_REUSE`` in both ``code`` and
    ``error_code``, revoke the active tokens of ``user-1``, commit once, and
    issue nothing.
    """
    session = _db_mock()
    sms_client = MagicMock()
    service = AuthService(db=session, sms=sms_client)

    reference_time = datetime.now(timezone.utc)
    replayed = _refresh_record(
        is_revoked=True,
        replaced_by_token_id="rt-2",
        rotated_at=reference_time - timedelta(seconds=120),
    )

    monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_refresh_token_by_token", AsyncMock(return_value=replayed))

    revoke_all_mock = AsyncMock(return_value=2)
    issue_mock = AsyncMock()
    monkeypatch.setattr(service, "_revoke_all_active_tokens_in_session", revoke_all_mock)
    monkeypatch.setattr(service, "_issue_tokens", issue_mock)

    with pytest.raises(AuthError) as exc_info:
        await service.refresh_tokens("old-refresh")

    # Checked after the context manager exits so the assertions actually run.
    raised = exc_info.value
    assert raised.code == "REFRESH_TOKEN_REUSE"
    assert raised.error_code == "REFRESH_TOKEN_REUSE"
    revoke_all_mock.assert_awaited_once_with("user-1")
    session.commit.assert_awaited_once()
    issue_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_refresh_unknown_token(monkeypatch) -> None:
    """A token that neither repo lookup finds fails with ``INVALID_TOKEN``.

    Both the consume call and the by-token lookup return ``None``. The test
    expects an ``AuthError``, exactly one rollback, and no call to the stubbed
    ``_issue_tokens``.
    """
    session = _db_mock()
    sms_client = MagicMock()
    service = AuthService(db=session, sms=sms_client)

    for lookup_name in ("try_consume_refresh_token", "get_refresh_token_by_token"):
        monkeypatch.setattr(repo, lookup_name, AsyncMock(return_value=None))
    issue_mock = AsyncMock()
    monkeypatch.setattr(service, "_issue_tokens", issue_mock)

    with pytest.raises(AuthError) as exc_info:
        await service.refresh_tokens("missing")

    assert exc_info.value.code == "INVALID_TOKEN"
    session.rollback.assert_awaited_once()
    issue_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_refresh_expired_token(monkeypatch) -> None:
    """A stored token whose ``expires_at`` is in the past fails with ``TOKEN_EXPIRED``.

    The consume call returns ``None`` and the by-token lookup returns a record
    built with ``expired=True``. The test expects an ``AuthError``, exactly one
    rollback, and no call to the stubbed ``_issue_tokens``.
    """
    session = _db_mock()
    sms_client = MagicMock()
    service = AuthService(db=session, sms=sms_client)

    stale_record = _refresh_record(expired=True)
    monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_refresh_token_by_token", AsyncMock(return_value=stale_record))
    issue_mock = AsyncMock()
    monkeypatch.setattr(service, "_issue_tokens", issue_mock)

    with pytest.raises(AuthError) as exc_info:
        await service.refresh_tokens("old-refresh")

    assert exc_info.value.code == "TOKEN_EXPIRED"
    session.rollback.assert_awaited_once()
    issue_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_refresh_user_deleted_rolls_back_consume(monkeypatch) -> None:
    """If the token's owner no longer exists, the refresh fails with ``USER_NOT_FOUND``.

    The consume call succeeds but the user lookup returns ``None``. The test
    expects an ``AuthError``, exactly one rollback, and no call to the stubbed
    ``_issue_tokens``.
    """
    session = _db_mock()
    sms_client = MagicMock()
    service = AuthService(db=session, sms=sms_client)

    consume_mock = AsyncMock(return_value=_refresh_record())
    missing_user_mock = AsyncMock(return_value=None)
    issue_mock = AsyncMock()
    monkeypatch.setattr(repo, "try_consume_refresh_token", consume_mock)
    monkeypatch.setattr(repo, "get_user_by_id", missing_user_mock)
    monkeypatch.setattr(service, "_issue_tokens", issue_mock)

    with pytest.raises(AuthError) as exc_info:
        await service.refresh_tokens("old-refresh")

    assert exc_info.value.code == "USER_NOT_FOUND"
    session.rollback.assert_awaited_once()
    issue_mock.assert_not_awaited()