"""ConversationHistoryStore transactional boundaries.""" from __future__ import annotations from contextlib import asynccontextmanager from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from sqlalchemy.ext.asyncio import AsyncSession from app.features.conversation.history_store import ConversationHistoryStore @asynccontextmanager async def _capture_transactional(db): yield db await db.commit() @pytest.mark.asyncio async def test_record_ai_only_turn_commits_before_redis_sync() -> None: db = MagicMock(spec=AsyncSession) db.commit = AsyncMock() redis_sync = AsyncMock() captured: list[object] = [] class FakeMsg: def __init__(self, **kwargs) -> None: self.id = "ai-1" for k, v in kwargs.items(): setattr(self, k, v) class _FakeRepo: @staticmethod def add_conversation_message(msg: object, _db) -> None: captured.append(msg) with patch( "app.features.conversation.history_store.transactional", _capture_transactional, ), patch( "app.features.conversation.history_store.ConversationMessage", FakeMsg, ), patch( "app.features.conversation.history_store.repo", _FakeRepo, ): store = ConversationHistoryStore(db) store._sync_redis_best_effort = redis_sync # type: ignore[method-assign] store._touch_conversation = AsyncMock() # type: ignore[method-assign] msg_id = await store.record_ai_only_turn("conv-1", ["hello"]) assert msg_id is not None assert len(captured) == 1 assert captured[0].id == msg_id db.commit.assert_awaited_once() redis_sync.assert_awaited_once_with("conv-1") @pytest.mark.asyncio async def test_attach_ai_tts_commits_repo_update_before_redis_sync() -> None: db = MagicMock(spec=AsyncSession) db.commit = AsyncMock() redis_sync = AsyncMock() repo_calls: list[tuple] = [] async def fake_set_latest(*args, **kwargs): repo_calls.append((args, kwargs)) return object() with patch( "app.features.conversation.history_store.transactional", _capture_transactional, ), patch( "app.features.conversation.history_store.repo.set_latest_ai_message_tts_audio_urls", fake_set_latest, ): store = ConversationHistoryStore(db) store._sync_redis_best_effort = redis_sync # type: ignore[method-assign] await store.attach_ai_tts_audio_urls( "conv-1", tts_audio_urls=["https://example.com/a.mp3"], segment_id="seg-1", ) assert len(repo_calls) == 1 db.commit.assert_awaited_once() redis_sync.assert_awaited_once_with("conv-1") @pytest.mark.asyncio async def test_record_human_ai_turn_commits_pair_before_redis_sync() -> None: db = MagicMock(spec=AsyncSession) db.commit = AsyncMock() redis_sync = AsyncMock() captured: list[object] = [] class FakeMsg: def __init__(self, **kwargs) -> None: self.id = kwargs.get("id") or f"msg-{len(captured)}" for k, v in kwargs.items(): setattr(self, k, v) class _FakeRepo: @staticmethod def add_conversation_message(msg: object, _db) -> None: captured.append(msg) with patch( "app.features.conversation.history_store.transactional", _capture_transactional, ), patch( "app.features.conversation.history_store.ConversationMessage", FakeMsg, ), patch( "app.features.conversation.history_store.repo", _FakeRepo, ): store = ConversationHistoryStore(db) store._sync_redis_best_effort = redis_sync # type: ignore[method-assign] store._touch_conversation = AsyncMock() # type: ignore[method-assign] out = await store.record_human_ai_turn( "conv-1", "hello", ["reply"], user_message_timestamp=datetime.now(timezone.utc), is_from_voice=False, voice_session_id=None, audio_duration_seconds=None, tts_audio_urls=None, segment_id="seg-1", ) assert out is not None assert len(captured) == 2 db.commit.assert_awaited_once() redis_sync.assert_awaited_once_with("conv-1")