"""Memoir queue scheduling and phase1 memory ingest idempotency.""" from __future__ import annotations from datetime import datetime, timezone from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest from sqlalchemy.ext.asyncio import AsyncSession from app.features.conversation.models import Conversation, Segment from app.features.user.models import User from app.features.conversation.ws import pipeline as ws_pipeline from app.features.memory.ingest_service import MemoryIngestService from app.tasks import memoir_tasks from sqlalchemy.exc import IntegrityError @pytest.mark.asyncio async def test_process_user_message_queues_memoir_after_lineage() -> None: db = MagicMock(spec=AsyncSession) user = User(id="u1", language_preference="zh") conversation = Conversation(id="conv-1", user_id="u1") segment = Segment( id="seg-1", conversation_id="conv-1", user_input_text="hello", processed=False, ) call_order: list[str] = [] async def _record_turn(*_args, **_kwargs): call_order.append("record_turn") segment.lineage_json = {"primary_user_message_id": "hum-1"} return SimpleNamespace( human_message_id="hum-1", assistant_message_id="ai-1", ) async def _queue_segment(*_args, **_kwargs): call_order.append("queue_segment") turn = SimpleNamespace( messages=["reply"], skip_tts=True, memory_retrieval_trace=None, ) with ( patch.object( ws_pipeline.chat_turn_service, "process_turn", AsyncMock(return_value=turn), ), patch.object( ws_pipeline.ConversationHistoryStore, "record_human_ai_turn_with_segment", AsyncMock(side_effect=_record_turn), ), patch.object( ws_pipeline, "_schedule_memoir_ingest_for_segment", AsyncMock(side_effect=_queue_segment), ), patch.object(ws_pipeline.manager, "active_connections", {}), patch.object(ws_pipeline.manager, "send_message", AsyncMock()), patch.object( ws_pipeline, "get_or_create_state", AsyncMock(return_value=SimpleNamespace(current_stage="childhood")), ), patch.object(ws_pipeline, "maybe_send_topic_chips_ws", AsyncMock()), ): await ws_pipeline.process_user_message( "conv-1", "hello", conversation, segment, db, user=user, user_message_timestamp=datetime.now(timezone.utc), ) assert call_order == ["record_turn", "queue_segment"] @pytest.mark.asyncio async def test_process_user_message_queues_memoir_on_ai_failure() -> None: db = MagicMock(spec=AsyncSession) user = User(id="u1", language_preference="zh") conversation = Conversation(id="conv-1", user_id="u1") segment = Segment( id="seg-1", conversation_id="conv-1", user_input_text="hello", processed=False, ) queued = False async def _queue_segment(*_args, **_kwargs): nonlocal queued queued = True turn = SimpleNamespace(messages=[], skip_tts=True, memory_retrieval_trace=None) with ( patch.object( ws_pipeline.chat_turn_service, "process_turn", AsyncMock(return_value=turn), ), patch.object( ws_pipeline.ConversationHistoryStore, "record_human_ai_turn_with_segment", AsyncMock(return_value=None), ), patch.object( ws_pipeline, "_schedule_memoir_ingest_for_segment", AsyncMock(side_effect=_queue_segment), ), patch.object(ws_pipeline.manager, "active_connections", {"conv-1": MagicMock()}), patch.object(ws_pipeline.manager, "send_message", AsyncMock()), ): await ws_pipeline.process_user_message( "conv-1", "hello", conversation, segment, db, user=user, user_message_timestamp=datetime.now(timezone.utc), ) assert queued is True @pytest.mark.asyncio async def test_ingest_batch_idempotent_by_segment_id(monkeypatch) -> None: stored: dict[str, SimpleNamespace] = {} create_calls = 0 class FakeSession: async def commit(self) -> None: pass async def flush(self) -> None: pass async def fake_get(db, *, user_id: str, segment_id: str): return stored.get(segment_id) async def fake_create_source(session, **kwargs): nonlocal create_calls create_calls += 1 sid = kwargs["segment_id"] src = SimpleNamespace(id=f"src-{create_calls}", segment_id=sid) stored[sid] = src return src async def fake_create_chunk(*_args, **kwargs): return SimpleNamespace(id=f"ch-{kwargs.get('chunk_index')}") class FakeEmbeddingService: def __init__(self, *_args, **_kwargs) -> None: pass async def embed_source(self, user_id: str, source_id: str) -> dict: return {"status": "success", "vectors_written": 1} monkeypatch.setattr( "app.features.memory.ingest_service.get_transcript_source_by_segment_id", fake_get, ) monkeypatch.setattr( "app.features.memory.ingest_service.create_source", fake_create_source, ) monkeypatch.setattr( "app.features.memory.ingest_service.create_chunk", fake_create_chunk, ) monkeypatch.setattr( "app.features.memory.ingest_service.MemoryEmbeddingService", FakeEmbeddingService, ) monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", False) service = MemoryIngestService( FakeSession(), # type: ignore[arg-type] embedding_provider=None, enrichment_scheduler=MagicMock(schedule_many=MagicMock(return_value=[])), ) items = [ ("c1", "hello", {"primary_user_message_id": "m1"}, "seg-1"), ] first = await service.ingest_transcripts_batch("u1", items) second = await service.ingest_transcripts_batch("u1", items) assert first == ["src-1"] assert second == ["src-1"] assert create_calls == 1 def test_phase1_memory_ingest_batch_sync_reraises_on_failure(monkeypatch) -> None: async def _fail(*_args, **_kwargs): raise RuntimeError("ingest unavailable") monkeypatch.setattr(memoir_tasks, "_memory_ingest_transcripts_batch", _fail) db = MagicMock() items = [("c1", "hello", None, "seg-1")] with pytest.raises(RuntimeError, match="ingest unavailable"): memoir_tasks._phase1_memory_ingest_batch_sync( db, "u1", items, memoir_correlation_id="corr-1", ) def test_phase1_memory_ingest_batch_sync_resolves_integrity_race(monkeypatch) -> None: async def _race(*_args, **_kwargs): raise IntegrityError("insert", {}, Exception("unique")) monkeypatch.setattr(memoir_tasks, "_memory_ingest_transcripts_batch", _race) db = MagicMock() existing = SimpleNamespace(id="src-existing") def _lookup(_db, *, user_id: str, segment_id: str): assert user_id == "u1" assert segment_id == "seg-1" return existing monkeypatch.setattr( memoir_tasks, "get_transcript_source_by_segment_id_sync", _lookup, ) items = [("c1", "hello", None, "seg-1")] ids = memoir_tasks._phase1_memory_ingest_batch_sync( db, "u1", items, memoir_correlation_id="corr-1", ) assert ids == ["src-existing"]