"""Memoir two-phase pipeline: Phase 2 trigger conditions and orchestrator fields."""

from unittest.mock import MagicMock

import pytest

from app.agents.memoir.orchestrator import MemoirOrchestrator
from app.agents.memoir.extraction_agent import ExtractionResult
from app.agents.memoir.classification_agent import ChapterClassifyResult
from app.agents.state_schema import MemoirStateSchema
from app.tasks.memoir_tasks import _should_trigger_phase2

# NOTE(review): `memoir` is not referenced in these tests; the threshold
# patches below target it by dotted path via app.tasks.memoir_tasks instead.
from app.features.memoir.constants import memoir


def test_segment_chapter_category_populated() -> None:
    """prepare_batches maps each segment id to the classifier's chapter category."""
    orch = MemoirOrchestrator()
    # Stub both agents with fixed results so the test is deterministic.
    orch.extraction_agent.extract = MagicMock(
        return_value=ExtractionResult(
            detected_stage="childhood", slots={"daily_life": "玩布娃娃"}
        )
    )
    orch.classification_agent.classify = MagicMock(
        return_value=ChapterClassifyResult(category="childhood", llm_said_none=False)
    )
    st = MemoirStateSchema(
        stage_order=["childhood"],
        current_stage="childhood",
        covered_stages=[],
        slots={},
    )

    def get_state() -> MemoirStateSchema:
        # Always hand back the same in-memory state object.
        return st

    def update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        # No-op stub: the assertion below only inspects prepare_batches' result.
        return st

    class _Seg:
        """Minimal segment stand-in carrying only `id` and `user_input_text`."""

        def __init__(self, sid: str, text: str) -> None:
            self.id = sid
            self.user_input_text = text

    s1 = _Seg("a1", "小时候喜欢玩布娃娃")
    prepared = orch.prepare_batches(
        segments=[s1],
        llm=MagicMock(),
        get_or_create_state=get_state,
        update_slot=update_slot,
    )
    assert prepared.segment_chapter_category["a1"] == "childhood"


@pytest.mark.parametrize(
    "count,total_chars,current_chars,expect",
    [
        (1, 10, 60, True),  # immediate via current segment chars
        (3, 5, 5, True),  # batch min segments
        (2, 100, 5, True),  # batch min total chars
        (2, 50, 5, False),  # below both accum thresholds
    ],
    ids=["immediate", "min_segments", "min_total_chars", "below_thresholds"],
)
def test_should_trigger_phase2_matrix(
    monkeypatch: pytest.MonkeyPatch,
    count: int,
    total_chars: int,
    current_chars: int,
    expect: bool,
) -> None:
    """_should_trigger_phase2 decision matrix under pinned thresholds.

    Thresholds are patched to immediate=50 chars, batch=3 segments and
    batch=80 chars so the expected outcomes do not depend on configuration.
    """
    monkeypatch.setattr(
        "app.tasks.memoir_tasks.memoir.narrative_immediate_char_threshold",
        50,
    )
    monkeypatch.setattr(
        "app.tasks.memoir_tasks.memoir.narrative_batch_min_segments",
        3,
    )
    monkeypatch.setattr(
        "app.tasks.memoir_tasks.memoir.narrative_batch_min_chars",
        80,
    )
    db = MagicMock()
    # The query's `.one()` row is mocked as (count, total_chars); presumably
    # the accumulated segment count and character total — confirm against
    # _should_trigger_phase2.
    db.execute.return_value.one.return_value = (count, total_chars)
    assert _should_trigger_phase2(db, "user-1", "childhood", current_chars) == expect