"""WS pipeline / history_store 原子持久化边界;memoir 调度顺序。""" from __future__ import annotations from contextlib import asynccontextmanager from datetime import datetime, timezone from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import uuid import pytest from sqlalchemy.ext.asyncio import AsyncSession from app.features.auth import models as _auth_models # noqa: F401 from app.features.conversation import models as _conv_models # noqa: F401 from app.features.memory import models as _memory_models # noqa: F401 from app.features.memoir import models as _memoir_models # noqa: F401 from app.features.payment import models as _payment_models # noqa: F401 from app.features.story import models as _story_models # noqa: F401 from app.features.user import models as _user_models # noqa: F401 from app.features.conversation.history_store import ConversationHistoryStore from app.features.conversation.models import ConversationMessage, Segment from app.features.conversation.ws import persist @asynccontextmanager async def _capture_transactional(db): yield db await db.commit() @pytest.mark.asyncio async def test_persist_message_tts_url_segment_commits_once() -> None: db = MagicMock(spec=AsyncSession) db.commit = AsyncMock() msg = ConversationMessage( id="msg-1", conversation_id="conv-1", role="ai", content="hi", tts_audio_urls=[], ) with patch("app.features.conversation.ws.persist.transactional", _capture_transactional): await persist.persist_message_tts_url_segment(db, msg, 0, "https://cos/0.mp3") assert msg.tts_audio_urls == ["https://cos/0.mp3"] db.commit.assert_awaited_once() @pytest.mark.asyncio async def test_persist_voice_segment_row_commits_segment_and_activity() -> None: from app.features.conversation.models import Conversation db = MagicMock(spec=AsyncSession) db.commit = AsyncMock() db.add = MagicMock() conv = Conversation(id="conv-1", user_id="u1") segment = Segment( id="seg-1", conversation_id="conv-1", user_input_text="hello", processed=False, ) with patch("app.features.conversation.ws.persist.transactional", _capture_transactional): await persist.persist_voice_segment_row(db, segment, conv) db.add.assert_called_once_with(segment) assert conv.last_message_at is not None db.commit.assert_awaited_once() @pytest.mark.asyncio async def test_record_human_ai_turn_with_segment_single_commit() -> None: db = MagicMock(spec=AsyncSession) db.commit = AsyncMock() db.rollback = AsyncMock() db.flush = AsyncMock() commit_count = 0 @asynccontextmanager async def counting_transactional(session): nonlocal commit_count yield session await session.commit() commit_count += 1 store = ConversationHistoryStore(db) segment = Segment( id="seg-1", conversation_id="conv-1", user_input_text="hi", processed=False, ) with ( patch( "app.features.conversation.history_store.transactional", counting_transactional, ), patch.object(store, "_touch_conversation", AsyncMock()), patch.object(store, "_sync_redis_best_effort", AsyncMock()), patch("app.features.conversation.history_store.repo.add_conversation_message"), ): turn_ids = await store.record_human_ai_turn_with_segment( "conv-1", "hello", ["reply"], segment, user_message_timestamp=datetime.now(timezone.utc), is_from_voice=True, voice_session_id="vs-1", audio_duration_seconds=3, agent_response="reply", ) assert turn_ids is not None assert commit_count == 1 db.flush.assert_awaited_once() assert segment.agent_response == "reply" assert segment.user_message_id == turn_ids.human_message_id assert segment.lineage_json is not None assert segment.lineage_json["primary_user_message_id"] == turn_ids.human_message_id assert segment.lineage_json["turns"][0]["user_message_id"] == turn_ids.human_message_id @pytest.mark.asyncio async def test_record_human_ai_turn_with_segment_postgres_flush_order() -> None: """Regression: Postgres FK on segments.user_message_id requires message INSERT first.""" from app.core.config import settings if not settings.database_url.startswith("postgresql"): pytest.skip("requires PostgreSQL") from sqlalchemy import text from app.core.db import AsyncSessionLocal, transactional from app.features.conversation.models import Conversation from app.features.user.models import User try: async with AsyncSessionLocal() as db: await db.execute(text("SELECT 1")) except Exception: pytest.skip("PostgreSQL not reachable") uid = str(uuid.uuid4()) cid = str(uuid.uuid4()) sid = str(uuid.uuid4()) now = datetime.now(timezone.utc) async with AsyncSessionLocal() as db: db.add( User( id=uid, phone=f"138{uuid.uuid4().int % 100_000_000:08d}", password_hash="x", nickname="t", subscription_type="free", created_at=now, ) ) conv = Conversation(id=cid, user_id=uid, last_message_at=now) db.add(conv) segment = Segment( id=sid, conversation_id=cid, user_input_text="hi", processed=False, created_at=now, ) async with transactional(db): db.add(segment) await db.refresh(segment) async with AsyncSessionLocal() as db: segment = await db.get(Segment, sid) store = ConversationHistoryStore(db) turn_ids = await store.record_human_ai_turn_with_segment( cid, "hello", ["reply"], segment, user_message_timestamp=now, is_from_voice=False, voice_session_id=None, audio_duration_seconds=None, agent_response="reply", ) assert turn_ids is not None assert segment.user_message_id == turn_ids.human_message_id @pytest.mark.asyncio async def test_attach_ai_tts_for_turn_single_commit() -> None: db = MagicMock(spec=AsyncSession) commit_count = 0 @asynccontextmanager async def counting_transactional(session): nonlocal commit_count yield session await session.commit() commit_count += 1 store = ConversationHistoryStore(db) segment = Segment( id="seg-1", conversation_id="conv-1", user_input_text="hi", processed=False, ) ai_row = ConversationMessage( id="ai-1", conversation_id="conv-1", role="ai", content="reply", ) async def _set_tts(*_args, **_kwargs): ai_row.tts_audio_urls = ["https://cos/a.mp3"] return ai_row with ( patch( "app.features.conversation.history_store.transactional", counting_transactional, ), patch.object(store, "_sync_redis_best_effort", AsyncMock()), patch( "app.features.conversation.history_store.repo.set_latest_ai_message_tts_audio_urls", _set_tts, ), ): await store.attach_ai_tts_for_turn( "conv-1", tts_audio_urls=["https://cos/a.mp3"], segment=segment, ) assert commit_count == 1 assert segment.tts_audio_urls == ["https://cos/a.mp3"] assert ai_row.tts_audio_urls == ["https://cos/a.mp3"]