"""回忆录:segment_skip_story_ids 与 batch 级短路条件(orchestrator 侧)。""" from types import SimpleNamespace from unittest.mock import MagicMock from app.agents.memoir.classification_agent import ( ChapterClassifyResult, ClassificationAgent, ) from app.agents.memoir.extraction_agent import ExtractionResult from app.agents.memoir.orchestrator import MemoirOrchestrator from app.agents.stage_constants import CHAT_STAGES from app.agents.state_schema import DEFAULT_STAGE_ORDER, MemoirStateSchema def _empty_state() -> MemoirStateSchema: """与生产默认一致的五阶段 stage_order(计划 §5-C 全量阶段管道覆盖)。""" assert list(CHAT_STAGES) == DEFAULT_STAGE_ORDER return MemoirStateSchema( stage_order=list(CHAT_STAGES), current_stage="childhood", covered_stages=[], slots={}, ) def test_prepare_batches_skip_story_id_when_llm_none_and_empty_slots() -> None: orch = MemoirOrchestrator() orch.extraction_agent.extract = MagicMock( return_value=ExtractionResult(detected_stage="career", slots={}) ) orch.classification_agent.classify = MagicMock( return_value=ChapterClassifyResult(category="summary", llm_said_none=True) ) st = _empty_state() def get_state() -> MemoirStateSchema: return st def update_slot( stage: str, slot_name: str, snippet: str, seg_ids: list[str] ) -> MemoirStateSchema: return st seg = SimpleNamespace(id="seg-skip-1", user_input_text="聊聊别的吧") p = orch.prepare_batches( segments=[seg], llm=MagicMock(), get_or_create_state=get_state, update_slot=update_slot, ) assert "seg-skip-1" in p.segment_skip_story_ids def test_prepare_batches_fragment_heuristic_not_in_skip_set() -> None: """fragment-only→summary 且 llm_said_none=False,不进入 skip 集合。""" orch = MemoirOrchestrator() orch.extraction_agent.extract = MagicMock( return_value=ExtractionResult(detected_stage="career", slots={}) ) orch.classification_agent = ClassificationAgent() st = _empty_state() def get_state() -> MemoirStateSchema: return st def update_slot( stage: str, slot_name: str, snippet: str, seg_ids: list[str] ) -> MemoirStateSchema: return st seg = SimpleNamespace(id="seg-frag-1", user_input_text="1999年出生") p = orch.prepare_batches( segments=[seg], llm=None, get_or_create_state=get_state, update_slot=update_slot, ) assert "seg-frag-1" not in p.segment_skip_story_ids def test_prepare_batches_mixed_batch_only_one_segment_in_skip_set() -> None: """同 category 两段:仅一段满足 skip 条件 → skip 集合仅含该段 id。""" orch = MemoirOrchestrator() orch.extraction_agent.extract = MagicMock( side_effect=[ ExtractionResult(detected_stage="career", slots={}), ExtractionResult(detected_stage="career", slots={"job": "戏剧演员"}), ] ) orch.classification_agent.classify = MagicMock( side_effect=[ ChapterClassifyResult(category="summary", llm_said_none=True), ChapterClassifyResult(category="summary", llm_said_none=False), ] ) st = _empty_state() def get_state() -> MemoirStateSchema: return st def update_slot( stage: str, slot_name: str, snippet: str, seg_ids: list[str] ) -> MemoirStateSchema: return st s1 = SimpleNamespace(id="mix-1", user_input_text="聊聊别的吧") s2 = SimpleNamespace(id="mix-2", user_input_text="后来当了演员") p = orch.prepare_batches( segments=[s1, s2], llm=MagicMock(), get_or_create_state=get_state, update_slot=update_slot, ) assert p.segment_skip_story_ids == {"mix-1"} assert len(p.category_to_segments.get("summary", [])) == 2 def test_batch_all_skip_predicate() -> None: """memoir_tasks 短路条件:batch_ids <= skip_ids。""" batch_ids = {"a", "b"} skip_ids = {"a"} assert not (batch_ids <= skip_ids) assert {"a"} <= {"a", "b"} assert {"a", "b"} <= {"a", "b"}