Files
life-echo/api/tests/test_memoir_skip_story.py

123 lines
3.9 KiB
Python
Raw Normal View History

"""回忆录segment_skip_story_ids 与 batch 级短路条件orchestrator 侧)。"""
from types import SimpleNamespace
from unittest.mock import MagicMock
from app.agents.memoir.classification_agent import (
ChapterClassifyResult,
ClassificationAgent,
)
from app.agents.memoir.extraction_agent import ExtractionResult
from app.agents.memoir.orchestrator import MemoirOrchestrator
from app.agents.state_schema import MemoirStateSchema
def _empty_state() -> MemoirStateSchema:
    """Build a minimal memoir state for the tests in this module.

    The state has a single stage that is also the current stage, no covered
    stages and no filled slots.
    """
    only_stage = "childhood"
    return MemoirStateSchema(
        stage_order=[only_stage],
        current_stage=only_stage,
        covered_stages=[],
        slots={},
    )
def test_prepare_batches_skip_story_id_when_llm_none_and_empty_slots() -> None:
    """A segment is put in the skip set when the LLM said "none" and no slots were extracted.

    Both agents are stubbed: extraction returns empty slots and classification
    returns ``llm_said_none=True``. The segment id must then appear in
    ``segment_skip_story_ids`` of the ``prepare_batches`` result.
    """
    orchestrator = MemoirOrchestrator()

    # Stub the two agents so the outcome depends only on the orchestrator.
    extract_stub = MagicMock(
        return_value=ExtractionResult(detected_stage="career", slots={})
    )
    classify_stub = MagicMock(
        return_value=ChapterClassifyResult(category="summary", llm_said_none=True)
    )
    orchestrator.extraction_agent.extract = extract_stub
    orchestrator.classification_agent.classify = classify_stub

    state = _empty_state()

    def fetch_state() -> MemoirStateSchema:
        return state

    def noop_update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        # Slot writes are irrelevant here; hand back the same state object.
        return state

    segment = SimpleNamespace(id="seg-skip-1", user_input_text="聊聊别的吧")
    prepared = orchestrator.prepare_batches(
        segments=[segment],
        llm=MagicMock(),
        get_or_create_state=fetch_state,
        update_slot=noop_update_slot,
    )

    skipped_ids = prepared.segment_skip_story_ids
    assert "seg-skip-1" in skipped_ids
def test_prepare_batches_fragment_heuristic_not_in_skip_set() -> None:
    """A fragment-only -> summary result with llm_said_none=False stays out of the skip set.

    Extraction is stubbed to return empty slots, while a real
    ``ClassificationAgent`` is used and ``llm=None`` is passed to
    ``prepare_batches``. The segment id must not appear in
    ``segment_skip_story_ids``.
    """
    orchestrator = MemoirOrchestrator()
    extract_stub = MagicMock(
        return_value=ExtractionResult(detected_stage="career", slots={})
    )
    orchestrator.extraction_agent.extract = extract_stub
    # Use the real classifier rather than a stub for this case.
    orchestrator.classification_agent = ClassificationAgent()

    state = _empty_state()

    def fetch_state() -> MemoirStateSchema:
        return state

    def noop_update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        # Slot writes are irrelevant here; hand back the same state object.
        return state

    segment = SimpleNamespace(id="seg-frag-1", user_input_text="1999年出生")
    prepared = orchestrator.prepare_batches(
        segments=[segment],
        llm=None,
        get_or_create_state=fetch_state,
        update_slot=noop_update_slot,
    )

    skipped_ids = prepared.segment_skip_story_ids
    assert "seg-frag-1" not in skipped_ids
def test_prepare_batches_mixed_batch_only_one_segment_in_skip_set() -> None:
    """Two segments share a category, but only one meets the skip condition.

    The first segment gets empty slots and ``llm_said_none=True``; the second
    gets a filled slot and ``llm_said_none=False``. The skip set must contain
    only the first segment's id, while both segments are still grouped under
    the "summary" category.
    """
    orchestrator = MemoirOrchestrator()

    # One stubbed result per segment, consumed in call order.
    extraction_results = [
        ExtractionResult(detected_stage="career", slots={}),
        ExtractionResult(detected_stage="career", slots={"job": "戏剧演员"}),
    ]
    classification_results = [
        ChapterClassifyResult(category="summary", llm_said_none=True),
        ChapterClassifyResult(category="summary", llm_said_none=False),
    ]
    orchestrator.extraction_agent.extract = MagicMock(
        side_effect=extraction_results
    )
    orchestrator.classification_agent.classify = MagicMock(
        side_effect=classification_results
    )

    state = _empty_state()

    def fetch_state() -> MemoirStateSchema:
        return state

    def noop_update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        # Slot writes are irrelevant here; hand back the same state object.
        return state

    first_segment = SimpleNamespace(id="mix-1", user_input_text="聊聊别的吧")
    second_segment = SimpleNamespace(id="mix-2", user_input_text="后来当了演员")
    prepared = orchestrator.prepare_batches(
        segments=[first_segment, second_segment],
        llm=MagicMock(),
        get_or_create_state=fetch_state,
        update_slot=noop_update_slot,
    )

    assert prepared.segment_skip_story_ids == {"mix-1"}
    summary_segments = prepared.category_to_segments.get("summary", [])
    assert len(summary_segments) == 2
def test_batch_all_skip_predicate() -> None:
    """Pin the subset semantics of the short-circuit condition ``batch_ids <= skip_ids``.

    Per the original note, this is the condition used on the memoir_tasks side.
    ``set.issubset`` is the method form of the ``<=`` operator for sets.
    """
    whole_batch = {"a", "b"}
    partial_skip = {"a"}

    # Only part of the batch is skipped, so the condition must be false.
    assert not whole_batch.issubset(partial_skip)

    # A proper subset of the skip set and an equal set both satisfy it.
    assert {"a"}.issubset({"a", "b"})
    assert {"a", "b"}.issubset({"a", "b"})