Files
life-echo/api/tests/test_memoir_pipeline_optimization.py

330 lines
10 KiB
Python
Raw Normal View History

"""Validation tests for memoir pipeline optimization (Phase A/B/C).
Tests:
- Phase1 batch path is now the default
- Memory enrichment is dispatched asynchronously
- Unified narrative unit executor produces correct results
- Post-commit fan-out includes quality pass
- Quality pass task handles title polishing
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from app.agents.memoir.classification_agent import ChapterClassifyResult
from app.agents.memoir.extraction_agent import ExtractionResult
from app.agents.memoir.orchestrator import MemoirOrchestrator
from app.agents.state_schema import MemoirStateSchema
from app.features.memoir.constants import memoir
# ---------------------------------------------------------------------------
# Phase1 batch path defaults
# ---------------------------------------------------------------------------
def test_phase1_batch_enabled_by_default() -> None:
    """``memoir.phase1_batch_llm_enabled`` must be True, with a chunk size of at least 1."""
    batch_flag = memoir.phase1_batch_llm_enabled
    assert batch_flag is True

    chunk_size = memoir.phase1_batch_llm_chunk_size
    assert chunk_size >= 1
def test_quality_pass_enabled_by_default() -> None:
    """``memoir.quality_pass_enabled`` is expected to be exactly True."""
    quality_pass_flag = memoir.quality_pass_enabled
    assert quality_pass_flag is True
# ---------------------------------------------------------------------------
# Phase1 orchestrator selects batch path when available
# ---------------------------------------------------------------------------
def test_orchestrator_tries_batch_first(monkeypatch: pytest.MonkeyPatch) -> None:
    """With the batch flag forced on, ``prepare_batches`` must call the batch helper.

    The orchestrator's ``_prepare_batches_via_batch_llm`` is swapped for a stub
    that notes it was invoked and returns an empty mock result; the test passes
    only if that stub ran.
    """

    def _childhood_state() -> MemoirStateSchema:
        # Minimal single-stage state, used both as the stub's result state and
        # as the state handed out by the callbacks below.
        return MemoirStateSchema(
            stage_order=["childhood"],
            current_stage="childhood",
            covered_stages=[],
            slots={},
        )

    class _Seg:
        # Bare stand-in exposing only the attributes this test sets.
        def __init__(self, segment_id: str) -> None:
            self.id = segment_id
            self.user_input_text = "test"

    batch_was_invoked = False

    def _recording_batch(*args, **kwargs):
        nonlocal batch_was_invoked
        batch_was_invoked = True
        return MagicMock(
            state=_childhood_state(),
            category_to_segments={},
            segment_skip_story_ids=set(),
            segment_chapter_category={},
        )

    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled",
        True,
    )
    orchestrator = MemoirOrchestrator()
    orchestrator._prepare_batches_via_batch_llm = _recording_batch

    shared_state = _childhood_state()

    def _get_state():
        return shared_state

    def _update_slot(*_args):
        return shared_state

    orchestrator.prepare_batches(
        segments=[_Seg("s1")],
        llm=MagicMock(),
        llm_fast=MagicMock(),
        get_or_create_state=_get_state,
        update_slot=_update_slot,
    )

    assert batch_was_invoked is True
def test_orchestrator_fallback_to_sequential(monkeypatch: pytest.MonkeyPatch) -> None:
    """If the batch path raises, ``prepare_batches`` must fall back to sequential extraction.

    The batch helper is replaced by a stub that records the attempt and then
    raises ``RuntimeError``. The extraction and classification agents are mocked
    so no real LLM is needed. Two things are verified:

    * the batch path was actually attempted -- without this check the test would
      also pass if the batch path were silently skipped, proving nothing about
      the fallback;
    * segment ``s1`` still appears in ``result.segment_chapter_category``.
    """
    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled",
        True,
    )
    orch = MemoirOrchestrator()

    # One entry per batch attempt, so the fallback can be shown to be real.
    batch_attempts: list[tuple] = []

    def fail_batch(*args, **kwargs):
        batch_attempts.append((args, kwargs))
        raise RuntimeError("batch LLM unavailable")

    orch._prepare_batches_via_batch_llm = fail_batch
    orch.extraction_agent.extract = MagicMock(
        return_value=ExtractionResult(
            detected_stage="childhood",
            slots={"place": "潍坊"},
        )
    )
    orch.classification_agent.classify = MagicMock(
        return_value=ChapterClassifyResult(category="childhood", llm_said_none=False)
    )
    st = MemoirStateSchema(
        stage_order=["childhood"],
        current_stage="childhood",
        covered_stages=[],
        slots={},
    )

    class _Seg:
        # Bare stand-in exposing only the attributes this test sets.
        def __init__(self, sid: str, text: str) -> None:
            self.id = sid
            self.user_input_text = text

    result = orch.prepare_batches(
        segments=[_Seg("s1", "我小时候玩球")],
        llm=MagicMock(),
        llm_fast=MagicMock(),
        get_or_create_state=lambda: st,
        update_slot=lambda *a: st,
    )

    assert batch_attempts, "batch path was never attempted, so no fallback was exercised"
    assert "s1" in result.segment_chapter_category
def test_orchestrator_sequential_filters_invalid_slots(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With the batch flag off, only the ``place`` slot may reach ``update_slot``.

    The mocked extractor returns a ``place`` slot and a ``hallucinated`` slot;
    the test expects exactly one ``update_slot`` call, for ``place``.
    """

    class _Seg:
        id = "s1"
        user_input_text = "我小时候在潍坊。"

    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled",
        False,
    )
    orchestrator = MemoirOrchestrator()

    extraction = ExtractionResult(
        detected_stage="childhood",
        slots={"place": "潍坊", "hallucinated": "bad"},
    )
    orchestrator.extraction_agent.extract = MagicMock(return_value=extraction)

    classification = ChapterClassifyResult(category="childhood", llm_said_none=False)
    orchestrator.classification_agent.classify = MagicMock(return_value=classification)

    state = MemoirStateSchema(
        stage_order=["childhood"],
        current_stage="childhood",
        covered_stages=[],
        slots={},
    )

    recorded_calls: list[tuple] = []

    def _record_update(*args):
        # Keep the positional arguments verbatim for the final comparison.
        recorded_calls.append(args)
        return state

    def _get_state():
        return state

    orchestrator.prepare_batches(
        segments=[_Seg()],
        llm=MagicMock(),
        get_or_create_state=_get_state,
        update_slot=_record_update,
    )

    expected_call = ("childhood", "place", "潍坊", ["s1"])
    assert recorded_calls == [expected_call]
# ---------------------------------------------------------------------------
# Memory enrichment decoupled from ingest
# ---------------------------------------------------------------------------
def test_memory_service_exposes_async_batch_ingest_only() -> None:
    """``MemoryService`` offers ``ingest_transcripts_batch``; no ``*_sync`` helpers remain.

    Also checks that the source of ``ingest_transcripts_batch`` mentions
    ``MemoryIngestService``.
    """
    import inspect

    import app.features.memory.service as memory_service_module
    from app.features.memory.service import MemoryService

    assert hasattr(MemoryService, "ingest_transcripts_batch")

    # The legacy names are assembled at runtime -- presumably so the literal
    # identifiers never appear in this file (TODO confirm intent).
    for stem in ("ingest_transcript", "ingest_transcripts_batch"):
        legacy_name = stem + "_sync"
        assert not hasattr(memory_service_module, legacy_name)

    batch_source = inspect.getsource(MemoryService.ingest_transcripts_batch)
    assert "MemoryIngestService" in batch_source
# ---------------------------------------------------------------------------
# Post-commit unified fan-out
# ---------------------------------------------------------------------------
def test_post_commit_result_includes_quality_pass() -> None:
    """A default ``PostCommitResult`` has ``quality_pass_scheduled`` set to False."""
    from app.features.story.post_commit import PostCommitResult

    default_result = PostCommitResult()

    assert hasattr(default_result, "quality_pass_scheduled")
    scheduled = default_result.quality_pass_scheduled
    assert scheduled is False
def test_post_commit_signature_accepts_quality_pass() -> None:
    """``enqueue_story_post_commit_effects`` declares the two expected parameters.

    Checks that ``need_quality_pass`` and ``memoir_correlation_id`` both appear
    in the function's signature.
    """
    import inspect

    from app.features.story.post_commit import enqueue_story_post_commit_effects

    declared = inspect.signature(enqueue_story_post_commit_effects).parameters
    for required_name in ("need_quality_pass", "memoir_correlation_id"):
        assert required_name in declared
# ---------------------------------------------------------------------------
# resolve_append_target
# ---------------------------------------------------------------------------
def test_resolve_append_target_forced_new_on_overflow() -> None:
    """An append onto a 200,000-character story must be turned into a new story.

    ``session.get`` returns a mocked story with an oversized
    ``canonical_markdown`` and ``count_story_versions_sync`` is patched to
    return 1. The expected outcome is a ``None`` target id and the decision
    source ``forced_new_due_to_append_limit``.
    """
    from app.features.memoir.story_pipeline_sync import _resolve_append_target

    oversized_story = MagicMock(
        user_id="u1",
        id="story-1",
        canonical_markdown="x" * 200_000,
    )
    db_session = MagicMock()
    db_session.get.return_value = oversized_story

    version_counter = "app.features.memoir.story_pipeline_sync.count_story_versions_sync"
    with patch(version_counter, return_value=1):
        target_id, _existing_text, resolved_source = _resolve_append_target(
            db_session,
            route_decision="append_story",
            route_target_story_id="story-1",
            user_id="u1",
            chapter_category="childhood",
            oral_norm="short text",
            candidate_stories=[],
            story_meta={},
            decision_source="test",
            memoir_correlation_id=None,
        )

    assert target_id is None
    assert resolved_source == "forced_new_due_to_append_limit"
def test_resolve_append_target_does_not_guardrail_route_fallback() -> None:
    """A ``new_story`` decision with source ``no_llm`` must pass through unchanged.

    Even with a candidate story available, the expected result is no target id,
    empty existing text, the original decision source, and no ``session.get``
    lookup.
    """
    from app.features.memoir.story_pipeline_sync import _resolve_append_target

    db_session = MagicMock()
    recent_candidate = MagicMock(id="story-1")
    candidate_meta = {"story-1": {"char_count": 10, "version_count": 1}}

    target_id, existing_text, resolved_source = _resolve_append_target(
        db_session,
        route_decision="new_story",
        route_target_story_id=None,
        user_id="u1",
        chapter_category="childhood",
        oral_norm="short text",
        candidate_stories=[recent_candidate],
        story_meta=candidate_meta,
        decision_source="no_llm",
        memoir_correlation_id=None,
    )

    assert target_id is None
    assert existing_text == ""
    assert resolved_source == "no_llm"
    db_session.get.assert_not_called()
# ---------------------------------------------------------------------------
# _run_post_pipeline_commit helper
# ---------------------------------------------------------------------------
def test_run_post_pipeline_commit_calls_post_commit() -> None:
    """``_run_post_pipeline_commit`` calls ``enqueue_story_post_commit_effects`` once.

    Both the post-commit entry point and ``MemoirImageSettings`` are patched.
    The test then checks that ``need_quality_pass`` and
    ``memoir_correlation_id`` arrive as keyword arguments with the values
    supplied here.
    """
    from app.tasks.memoir_tasks import _run_post_pipeline_commit

    post_commit_target = "app.features.story.post_commit.enqueue_story_post_commit_effects"
    image_settings_target = "app.features.memoir.memoir_images.settings.MemoirImageSettings"

    with (
        patch(post_commit_target) as post_commit_mock,
        patch(image_settings_target) as image_settings_mock,
    ):
        post_commit_mock.return_value = MagicMock(
            enqueued_story_image_count=0,
            enqueued_chapter_recompose_count=0,
            compaction_scheduled=False,
            quality_pass_scheduled=True,
            errors=[],
        )
        image_settings_mock.from_env.return_value = MagicMock(enabled=False)

        _run_post_pipeline_commit(
            user_id="u1",
            story_dispatch_ids={"s1"},
            recompose_chapter_ids={"c1"},
            cover_chapter_ids=set(),
            trigger_source="test",
            need_compaction=False,
            need_quality_pass=True,
            memoir_correlation_id="cid-1",
        )

        post_commit_mock.assert_called_once()
        forwarded_kwargs = post_commit_mock.call_args.kwargs
        assert forwarded_kwargs["need_quality_pass"] is True
        assert forwarded_kwargs["memoir_correlation_id"] == "cid-1"