配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM)
235 lines
7.0 KiB
Python
235 lines
7.0 KiB
Python
"""Refresh token rotation and reuse detection."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.features.auth import repo
|
|
from app.features.auth.models import RefreshToken
|
|
from app.features.auth.service import AuthError, AuthService
|
|
|
|
|
|
def _refresh_record(
    *,
    token: str = "old-refresh",
    user_id: str = "user-1",
    is_revoked: bool = False,
    expired: bool = False,
    replaced_by_token_id: str | None = None,
    rotated_at: datetime | None = None,
    record_id: str = "rt-1",
    device_info: str = "iphone",
) -> MagicMock:
    """Build a ``RefreshToken`` test double with the given state.

    The mock is created with ``spec=RefreshToken``, so reading an attribute the
    model does not define raises ``AttributeError`` instead of silently
    returning a child mock.

    Args:
        token: Opaque refresh-token string stored on the record.
        user_id: Id of the user that owns the token.
        is_revoked: Whether the record is already revoked.
        expired: When true, ``expires_at`` lies one day in the past;
            otherwise it lies 30 days in the future.
        replaced_by_token_id: Id of the token that superseded this one, if any.
        rotated_at: When this token was rotated, if ever.
        record_id: Primary key of the record (defaults to ``"rt-1"``).
        device_info: Device label stored on the record; the service forwards
            it when issuing replacement tokens.

    Returns:
        A ``MagicMock`` whose ``RefreshToken`` attributes are populated.
    """
    now = datetime.now(timezone.utc)
    if expired:
        expires_at = now - timedelta(days=1)
        # Keep the record internally consistent: an expired token must have
        # been created before it expired, not at "now".
        created_at = expires_at - timedelta(days=30)
    else:
        expires_at = now + timedelta(days=30)
        created_at = now

    record = MagicMock(spec=RefreshToken)
    record.id = record_id
    record.user_id = user_id
    record.token = token
    record.expires_at = expires_at
    record.created_at = created_at
    record.is_revoked = is_revoked
    record.device_info = device_info
    record.replaced_by_token_id = replaced_by_token_id
    record.rotated_at = rotated_at
    return record
|
|
|
|
|
|
def _db_mock() -> MagicMock:
|
|
db = MagicMock()
|
|
db.commit = AsyncMock()
|
|
db.rollback = AsyncMock()
|
|
db.refresh = AsyncMock()
|
|
db.flush = AsyncMock()
|
|
return db
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_rotates_token_in_transaction(monkeypatch) -> None:
    """A valid refresh token is consumed, replaced and linked in one commit.

    Happy path:
    - the presented token is consumed;
    - new tokens are issued for the record's user and device;
    - the old record is linked to its replacement;
    - the unit of work is committed exactly once, with no rollback.
    """
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())
    consumed = _refresh_record()

    monkeypatch.setattr(
        repo,
        "try_consume_refresh_token",
        AsyncMock(return_value=consumed),
    )
    monkeypatch.setattr(
        repo,
        "get_user_by_id",
        AsyncMock(return_value=MagicMock(id="user-1")),
    )
    monkeypatch.setattr(
        svc,
        "_issue_tokens",
        AsyncMock(
            return_value={
                "access_token": "new-access",
                "refresh_token": "new-refresh",
                "refresh_token_id": "rt-2",
            }
        ),
    )
    monkeypatch.setattr(repo, "link_refresh_rotation", AsyncMock())

    result = await svc.refresh_tokens("old-refresh")

    assert result["access_token"] == "new-access"
    # Exact equality also rules out echoing the presented token back.
    assert result["refresh_token"] == "new-refresh"

    # Transaction boundary: exactly one commit and no rollback on success.
    db.commit.assert_awaited_once()
    db.rollback.assert_not_awaited()

    svc._issue_tokens.assert_awaited_once_with("user-1", "iphone")
    repo.try_consume_refresh_token.assert_awaited_once_with("old-refresh", db)
    repo.link_refresh_rotation.assert_awaited_once()
    # The first two positional args are (old record id, replacement id).
    assert repo.link_refresh_rotation.await_args.args[:2] == ("rt-1", "rt-2")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_grace_reuse_returns_idempotent_tokens(monkeypatch) -> None:
    """Replaying a token that was rotated just now returns its replacement.

    Setup: the presented token is revoked and points at replacement "rt-2",
    with ``rotated_at`` set to the current time.

    Expected: the service hands back the existing replacement token and
    commits once. It neither revokes all active tokens nor issues new ones.
    """
    session = _db_mock()
    service = AuthService(db=session, sms=MagicMock())
    rotated_just_now = datetime.now(timezone.utc)

    presented = _refresh_record(
        is_revoked=True,
        replaced_by_token_id="rt-2",
        rotated_at=rotated_just_now,
    )
    successor = _refresh_record(token="new-refresh", is_revoked=False)
    successor.id = "rt-2"

    # Patch the repository lookups in the same order the original test did.
    repo_doubles = {
        "try_consume_refresh_token": AsyncMock(return_value=None),
        "get_refresh_token_by_token": AsyncMock(return_value=presented),
        "get_refresh_token_by_id": AsyncMock(return_value=successor),
        "get_user_by_id": AsyncMock(return_value=MagicMock(id="user-1")),
    }
    for attr_name, double in repo_doubles.items():
        monkeypatch.setattr(repo, attr_name, double)

    revoke_all = AsyncMock(return_value=2)
    issue_tokens = AsyncMock()
    monkeypatch.setattr(service, "_revoke_all_active_tokens_in_session", revoke_all)
    monkeypatch.setattr(service, "_issue_tokens", issue_tokens)

    result = await service.refresh_tokens("old-refresh")

    assert result["refresh_token"] == "new-refresh"
    assert result["access_token"]
    revoke_all.assert_not_awaited()
    issue_tokens.assert_not_awaited()
    session.commit.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_reuse_revokes_all_sessions(monkeypatch) -> None:
    """Replaying a token rotated 120 seconds ago is rejected as reuse.

    The replay is presumed to fall outside the service's grace window. The
    service should:
    - revoke every active token of the owning user;
    - commit that revocation;
    - raise ``REFRESH_TOKEN_REUSE`` without issuing new tokens.
    """
    session = _db_mock()
    service = AuthService(db=session, sms=MagicMock())
    two_minutes_ago = datetime.now(timezone.utc) - timedelta(seconds=120)

    stale = _refresh_record(
        is_revoked=True,
        replaced_by_token_id="rt-2",
        rotated_at=two_minutes_ago,
    )

    monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_refresh_token_by_token", AsyncMock(return_value=stale))

    revoke_all = AsyncMock(return_value=2)
    issue_tokens = AsyncMock()
    monkeypatch.setattr(service, "_revoke_all_active_tokens_in_session", revoke_all)
    monkeypatch.setattr(service, "_issue_tokens", issue_tokens)

    with pytest.raises(AuthError) as exc_info:
        await service.refresh_tokens("old-refresh")

    error = exc_info.value
    assert error.code == "REFRESH_TOKEN_REUSE"
    assert error.error_code == "REFRESH_TOKEN_REUSE"
    # The revocation must be persisted even though the request itself fails.
    revoke_all.assert_awaited_once_with("user-1")
    session.commit.assert_awaited_once()
    issue_tokens.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_unknown_token(monkeypatch) -> None:
    """An unrecognised refresh token is rejected and nothing is persisted.

    Neither the consume attempt nor the plain lookup finds the token. The
    service must raise ``INVALID_TOKEN``, roll back once, never commit and
    never issue tokens.
    """
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())

    monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None))
    monkeypatch.setattr(repo, "get_refresh_token_by_token", AsyncMock(return_value=None))
    monkeypatch.setattr(svc, "_issue_tokens", AsyncMock())

    with pytest.raises(AuthError) as exc_info:
        await svc.refresh_tokens("missing")

    assert exc_info.value.code == "INVALID_TOKEN"
    db.rollback.assert_awaited_once()
    # Asserting the rollback alone would not catch a commit issued before it.
    db.commit.assert_not_awaited()
    svc._issue_tokens.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_expired_token(monkeypatch) -> None:
    """An expired refresh token is rejected and nothing is persisted.

    The consume attempt fails, and the lookup returns a record whose
    ``expires_at`` is in the past. The service must raise ``TOKEN_EXPIRED``,
    roll back once, never commit and never issue tokens.
    """
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())
    expired = _refresh_record(expired=True)

    monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None))
    monkeypatch.setattr(
        repo,
        "get_refresh_token_by_token",
        AsyncMock(return_value=expired),
    )
    monkeypatch.setattr(svc, "_issue_tokens", AsyncMock())

    with pytest.raises(AuthError) as exc_info:
        await svc.refresh_tokens("old-refresh")

    assert exc_info.value.code == "TOKEN_EXPIRED"
    db.rollback.assert_awaited_once()
    # Asserting the rollback alone would not catch a commit issued before it.
    db.commit.assert_not_awaited()
    svc._issue_tokens.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_user_deleted_rolls_back_consume(monkeypatch) -> None:
    """Refreshing for a user who no longer exists undoes the token consume.

    ``try_consume_refresh_token`` succeeds, but the owner lookup returns
    ``None``. The service must raise ``USER_NOT_FOUND`` and roll back without
    committing, so the consumed token is not persisted.
    """
    db = _db_mock()
    svc = AuthService(db=db, sms=MagicMock())
    consumed = _refresh_record()

    monkeypatch.setattr(
        repo,
        "try_consume_refresh_token",
        AsyncMock(return_value=consumed),
    )
    monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=None))
    monkeypatch.setattr(svc, "_issue_tokens", AsyncMock())

    with pytest.raises(AuthError) as exc_info:
        await svc.refresh_tokens("old-refresh")

    assert exc_info.value.code == "USER_NOT_FOUND"
    db.rollback.assert_awaited_once()
    # A commit here would persist the very consume the rollback must undo.
    db.commit.assert_not_awaited()
    svc._issue_tokens.assert_not_awaited()
|