Files
life-echo/api/tests/test_memoir_skip_story.py

187 lines
5.9 KiB
Python
Raw Normal View History

"""回忆录segment_skip_story_ids 与 batch 级短路条件orchestrator 侧)。"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from app.agents.memoir.batch_phase1_prep import BatchPhase1SegmentRow
from app.agents.memoir.classification_agent import (
ChapterClassifyResult,
ClassificationAgent,
)
from app.agents.memoir.extraction_agent import ExtractionResult
from app.agents.memoir.orchestrator import MemoirOrchestrator
from app.agents.stage_constants import CHAT_STAGES
from app.agents.state_schema import DEFAULT_STAGE_ORDER, MemoirStateSchema
from app.features.memoir.constants import memoir
def _empty_state() -> MemoirStateSchema:
    """Build a blank memoir state that uses the production five-stage order.

    Matches the production default ``stage_order`` (plan §5-C: coverage of the
    full stage pipeline). The assertion fails fast if ``CHAT_STAGES`` and
    ``DEFAULT_STAGE_ORDER`` ever diverge, since the tests below would then run
    against a non-default stage order.

    Returns:
        A ``MemoirStateSchema`` positioned at ``"childhood"`` with no covered
        stages and no slots.
    """
    stages = list(CHAT_STAGES)
    assert stages == DEFAULT_STAGE_ORDER
    return MemoirStateSchema(
        slots={},
        covered_stages=[],
        current_stage="childhood",
        stage_order=stages,
    )
def test_prepare_batches_skip_story_id_when_llm_none_and_empty_slots(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A segment with ``llm_said_none=True`` and no extracted slots is skipped.

    The extraction agent returns no slots, and the classification agent reports
    that the LLM said "none". The segment id must then be the only entry in
    ``segment_skip_story_ids``.
    """
    # Pin the per-segment phase-1 path. The per-segment extraction and
    # classification agents are mocked below, so the result must not depend
    # on the default value of the batch-LLM feature flag.
    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled",
        False,
    )
    orch = MemoirOrchestrator()
    orch.extraction_agent.extract = MagicMock(
        return_value=ExtractionResult(detected_stage="career", slots={})
    )
    orch.classification_agent.classify = MagicMock(
        return_value=ChapterClassifyResult(category="summary", llm_said_none=True)
    )
    st = _empty_state()

    def get_state() -> MemoirStateSchema:
        return st

    def update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        return st

    seg = SimpleNamespace(id="seg-skip-1", user_input_text="聊聊别的吧")
    p = orch.prepare_batches(
        segments=[seg],
        llm=MagicMock(),
        get_or_create_state=get_state,
        update_slot=update_slot,
    )
    # Exact equality (not just membership), so unrelated ids leaking into the
    # skip set also fail the test; this matches the mixed-batch tests.
    assert p.segment_skip_story_ids == {"seg-skip-1"}
def test_prepare_batches_fragment_heuristic_not_in_skip_set() -> None:
    """A fragment-only input classified as "summary" must stay out of the skip set.

    Uses the real ``ClassificationAgent`` with ``llm=None``, presumably so that
    the category comes from its non-LLM heuristic (TODO confirm). Here
    ``llm_said_none`` is expected to be False, so the segment must not be
    added to ``segment_skip_story_ids``.
    """
    orchestrator = MemoirOrchestrator()
    extract_mock = MagicMock()
    extract_mock.return_value = ExtractionResult(detected_stage="career", slots={})
    orchestrator.extraction_agent.extract = extract_mock
    orchestrator.classification_agent = ClassificationAgent()
    state = _empty_state()

    def _get_state() -> MemoirStateSchema:
        return state

    def _update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        return state

    fragment = SimpleNamespace(id="seg-frag-1", user_input_text="1999年出生")
    prepared = orchestrator.prepare_batches(
        segments=[fragment],
        llm=None,
        get_or_create_state=_get_state,
        update_slot=_update_slot,
    )
    skipped_ids = prepared.segment_skip_story_ids
    assert "seg-frag-1" not in skipped_ids
def test_prepare_batches_mixed_batch_only_one_segment_in_skip_set(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Two same-category segments where only one meets the skip condition.

    ``mix-1`` has no slots and ``llm_said_none=True``, so it is skipped.
    ``mix-2`` has a slot and ``llm_said_none=False``, so it is kept. The skip set
    must contain only ``mix-1``, while both segments still land in the shared
    "summary" bucket.
    """
    # Pin the per-segment phase-1 path. The agents are mocked below, so the
    # outcome must not depend on the default value of the batch-LLM flag.
    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled",
        False,
    )
    orch = MemoirOrchestrator()
    # side_effect lists are consumed in call order. This presumably means one
    # call per segment, in segment order.
    orch.extraction_agent.extract = MagicMock(
        side_effect=[
            ExtractionResult(detected_stage="career", slots={}),
            ExtractionResult(detected_stage="career", slots={"job": "戏剧演员"}),
        ]
    )
    orch.classification_agent.classify = MagicMock(
        side_effect=[
            ChapterClassifyResult(category="summary", llm_said_none=True),
            ChapterClassifyResult(category="summary", llm_said_none=False),
        ]
    )
    st = _empty_state()

    def get_state() -> MemoirStateSchema:
        return st

    def update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        return st

    s1 = SimpleNamespace(id="mix-1", user_input_text="聊聊别的吧")
    s2 = SimpleNamespace(id="mix-2", user_input_text="后来当了演员")
    p = orch.prepare_batches(
        segments=[s1, s2],
        llm=MagicMock(),
        get_or_create_state=get_state,
        update_slot=update_slot,
    )
    assert p.segment_skip_story_ids == {"mix-1"}
    assert len(p.category_to_segments.get("summary", [])) == 2
def test_prepare_batches_batch_llm_path_matches_per_segment_skip_logic(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The batch-LLM phase-1 path yields the same skip set as the per-segment path.

    Uses the same scenario as the per-segment mixed-batch test:
    - ``mix-1`` comes back with ``chapter_category_raw="none"`` and no slots, so
      it must be skipped.
    - ``mix-2`` comes back with ``"summary"`` and a slot, so it is kept.

    Both segments end up in the "summary" bucket.
    """
    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled",
        True,
    )
    # Records each invocation of the stub. This proves the batch path was
    # actually taken, rather than the assertions passing via another path.
    batch_calls: list[object] = []

    def fake_batch(
        segments: list,
        state: MemoirStateSchema,
        llm: object,
        *,
        chunk_size: int = 24,
        on_chunk=None,
        language: str = "zh",
    ) -> dict[str, BatchPhase1SegmentRow]:
        batch_calls.append(segments)
        return {
            "mix-1": BatchPhase1SegmentRow(
                detected_stage="career",
                slots={},
                chapter_category_raw="none",
            ),
            "mix-2": BatchPhase1SegmentRow(
                detected_stage="career",
                slots={"job": "戏剧演员"},
                chapter_category_raw="summary",
            ),
        }

    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.run_batch_phase1_prep_chunked",
        fake_batch,
    )
    orch = MemoirOrchestrator()
    st = _empty_state()

    def get_state() -> MemoirStateSchema:
        return st

    def update_slot(
        stage: str, slot_name: str, snippet: str, seg_ids: list[str]
    ) -> MemoirStateSchema:
        return st

    s1 = SimpleNamespace(id="mix-1", user_input_text="聊聊别的吧")
    s2 = SimpleNamespace(id="mix-2", user_input_text="后来当了演员")
    p = orch.prepare_batches(
        segments=[s1, s2],
        llm=MagicMock(),
        get_or_create_state=get_state,
        update_slot=update_slot,
    )
    assert batch_calls, "batch phase-1 path was not taken"
    assert p.segment_skip_story_ids == {"mix-1"}
    assert len(p.category_to_segments.get("summary", [])) == 2
def test_batch_all_skip_predicate() -> None:
    """Pin the memoir_tasks short-circuit predicate ``batch_ids <= skip_ids``.

    The predicate is a subset test. It holds only when every id in the batch
    is also in the skip set, so a single non-skipped id makes it False.
    """
    cases = (
        # (batch_ids, skip_ids, expected result of batch_ids <= skip_ids)
        ({"a", "b"}, {"a"}, False),  # "b" is not skipped
        ({"a"}, {"a", "b"}, True),
        ({"a", "b"}, {"a", "b"}, True),
    )
    for batch_ids, skip_ids, expected in cases:
        assert (batch_ids <= skip_ids) == expected