"""transactional / transactional_sync commit and rollback behavior.""" from __future__ import annotations from contextlib import asynccontextmanager, contextmanager from unittest.mock import AsyncMock, MagicMock import pytest from app.core.db import ( transactional, transactional_nested, transactional_nested_sync, transactional_sync, ) @pytest.mark.asyncio async def test_transactional_commits_on_success() -> None: session = MagicMock() session.commit = AsyncMock() session.rollback = AsyncMock() async with transactional(session): pass session.commit.assert_awaited_once() session.rollback.assert_not_awaited() @pytest.mark.asyncio async def test_transactional_rolls_back_on_error() -> None: session = MagicMock() session.commit = AsyncMock() session.rollback = AsyncMock() with pytest.raises(RuntimeError, match="boom"): async with transactional(session): raise RuntimeError("boom") session.commit.assert_not_awaited() session.rollback.assert_awaited_once() def test_transactional_sync_commits_on_success() -> None: session = MagicMock() session.commit = MagicMock() session.rollback = MagicMock() with transactional_sync(session): pass session.commit.assert_called_once() session.rollback.assert_not_called() def test_transactional_sync_rolls_back_on_error() -> None: session = MagicMock() session.commit = MagicMock() session.rollback = MagicMock() with pytest.raises(RuntimeError, match="boom"): with transactional_sync(session): raise RuntimeError("boom") session.commit.assert_not_called() session.rollback.assert_called_once() @pytest.mark.asyncio async def test_transactional_nested_releases_savepoint_on_success() -> None: session = MagicMock() session.commit = AsyncMock() session.rollback = AsyncMock() @asynccontextmanager async def fake_begin_nested(): yield session session.begin_nested = MagicMock(return_value=fake_begin_nested()) async with transactional_nested(session): pass session.begin_nested.assert_called_once() session.commit.assert_not_awaited() session.rollback.assert_not_awaited() @pytest.mark.asyncio async def test_transactional_nested_rolls_back_savepoint_on_error() -> None: session = MagicMock() session.commit = AsyncMock() session.rollback = AsyncMock() @asynccontextmanager async def fake_begin_nested(): yield session session.begin_nested = MagicMock(return_value=fake_begin_nested()) with pytest.raises(RuntimeError, match="boom"): async with transactional_nested(session): raise RuntimeError("boom") session.begin_nested.assert_called_once() session.commit.assert_not_awaited() session.rollback.assert_not_awaited() def test_transactional_nested_sync_releases_savepoint_on_success() -> None: session = MagicMock() session.commit = MagicMock() session.rollback = MagicMock() @contextmanager def fake_begin_nested(): yield session session.begin_nested = MagicMock(return_value=fake_begin_nested()) with transactional_nested_sync(session): pass session.begin_nested.assert_called_once() session.commit.assert_not_called() session.rollback.assert_not_called() def test_transactional_nested_sync_rolls_back_savepoint_on_error() -> None: session = MagicMock() session.commit = MagicMock() session.rollback = MagicMock() @contextmanager def fake_begin_nested(): yield session session.begin_nested = MagicMock(return_value=fake_begin_nested()) with pytest.raises(RuntimeError, match="boom"): with transactional_nested_sync(session): raise RuntimeError("boom") session.begin_nested.assert_called_once() session.commit.assert_not_called() session.rollback.assert_not_called()