Files
life-echo/api/tests/test_memoir_pipeline_optimization.py

253 lines
8.3 KiB
Python

"""Validation tests for memoir pipeline optimization (Phase A/B/C).
Tests:
- Phase1 batch path is now the default
- Memory enrichment is dispatched asynchronously
- Unified narrative unit executor produces correct results
- Post-commit fan-out includes quality pass
- Quality pass task handles title polishing
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from app.agents.memoir.extraction_agent import ExtractionResult
from app.agents.memoir.classification_agent import ChapterClassifyResult
from app.agents.memoir.orchestrator import MemoirOrchestrator
from app.agents.state_schema import MemoirStateSchema
# ---------------------------------------------------------------------------
# Phase1 batch path defaults
# ---------------------------------------------------------------------------
def test_phase1_batch_enabled_by_default() -> None:
    """Check that a freshly built ``Settings`` enables the Phase1 batch LLM path.

    Also checks that the configured batch chunk size is at least 1.
    """
    from app.core.config import Settings

    config = Settings()

    batch_enabled = config.memoir_phase1_batch_llm_enabled
    assert batch_enabled is True

    chunk_size = config.memoir_phase1_batch_llm_chunk_size
    assert chunk_size >= 1
def test_quality_pass_enabled_by_default() -> None:
    """Check that a freshly built ``Settings`` has the quality pass enabled."""
    from app.core.config import Settings

    quality_pass_flag = Settings().memoir_quality_pass_enabled
    assert quality_pass_flag is True
# ---------------------------------------------------------------------------
# Phase1 orchestrator selects batch path when available
# ---------------------------------------------------------------------------
def test_orchestrator_tries_batch_first(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``prepare_batches`` calls the batch-LLM helper when the flag is on.

    The orchestrator's ``_prepare_batches_via_batch_llm`` is swapped for a
    recording stub, so the test only observes that the stub was invoked.
    """

    def _childhood_state() -> MemoirStateSchema:
        # Minimal single-stage state; built once for the stub's return value
        # and once for the state handed to the orchestrator callbacks.
        return MemoirStateSchema(
            stage_order=["childhood"],
            current_stage="childhood",
            covered_stages=[],
            slots={},
        )

    class _Seg:
        # Bare stand-in exposing only the two attributes this test sets.
        def __init__(self, segment_id: str) -> None:
            self.id = segment_id
            self.user_input_text = "test"

    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled",
        True,
    )
    orchestrator = MemoirOrchestrator()
    invocation = {"flag": False}

    def _recording_batch(*args, **kwargs):
        # Record the call, then return an object shaped like a batch result.
        invocation["flag"] = True
        return MagicMock(
            state=_childhood_state(),
            category_to_segments={},
            segment_skip_story_ids=set(),
            segment_chapter_category={},
        )

    orchestrator._prepare_batches_via_batch_llm = _recording_batch
    current_state = _childhood_state()

    orchestrator.prepare_batches(
        segments=[_Seg("s1")],
        llm=MagicMock(),
        llm_fast=MagicMock(),
        get_or_create_state=lambda: current_state,
        update_slot=lambda *_ignored: current_state,
    )

    assert invocation["flag"] is True
def test_orchestrator_fallback_to_sequential(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``prepare_batches`` still categorizes a segment when the batch helper raises.

    The batch helper is replaced with a function that raises ``RuntimeError``;
    the extraction and classification agents are mocked to return fixed
    results. The segment id must then appear in
    ``result.segment_chapter_category``.
    """

    class _Seg:
        # Bare stand-in exposing only the two attributes this test sets.
        def __init__(self, segment_id: str, spoken_text: str) -> None:
            self.id = segment_id
            self.user_input_text = spoken_text

    def _batch_always_fails(*args, **kwargs):
        raise RuntimeError("batch LLM unavailable")

    monkeypatch.setattr(
        "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled",
        True,
    )
    orchestrator = MemoirOrchestrator()
    orchestrator._prepare_batches_via_batch_llm = _batch_always_fails

    # Canned agent outputs for the per-segment path.
    extraction_outcome = ExtractionResult(
        detected_stage="childhood",
        slots={"toy": "ball"},
    )
    orchestrator.extraction_agent.extract = MagicMock(return_value=extraction_outcome)
    classification_outcome = ChapterClassifyResult(
        category="childhood",
        llm_said_none=False,
    )
    orchestrator.classification_agent.classify = MagicMock(
        return_value=classification_outcome
    )

    current_state = MemoirStateSchema(
        stage_order=["childhood"],
        current_stage="childhood",
        covered_stages=[],
        slots={},
    )

    prepared = orchestrator.prepare_batches(
        segments=[_Seg("s1", "我小时候玩球")],
        llm=MagicMock(),
        llm_fast=MagicMock(),
        get_or_create_state=lambda: current_state,
        update_slot=lambda *_ignored: current_state,
    )

    assert "s1" in prepared.segment_chapter_category
# ---------------------------------------------------------------------------
# Memory enrichment decoupled from ingest
# ---------------------------------------------------------------------------
def test_ingest_transcript_sync_no_longer_calls_enrichment_inline() -> None:
    """Check the source text of ``ingest_transcript_sync`` for the enrichment hand-off.

    The function's source must not mention ``enrich_memory_after_ingest_sync``
    and must mention ``schedule_memory_enrichment``.
    """
    import inspect

    from app.features.memory.service import ingest_transcript_sync

    function_text = inspect.getsource(ingest_transcript_sync)

    mentions_inline_enrichment = "enrich_memory_after_ingest_sync" in function_text
    assert not mentions_inline_enrichment
    assert "schedule_memory_enrichment" in function_text
# ---------------------------------------------------------------------------
# Post-commit unified fan-out
# ---------------------------------------------------------------------------
def test_post_commit_result_includes_quality_pass() -> None:
    """Check that a default ``PostCommitResult`` has ``quality_pass_scheduled`` set to False."""
    from app.features.story.post_commit import PostCommitResult

    default_result = PostCommitResult()

    # Presence first, so a missing field fails here rather than on access.
    assert hasattr(default_result, "quality_pass_scheduled")
    scheduled = default_result.quality_pass_scheduled
    assert scheduled is False
def test_post_commit_signature_accepts_quality_pass() -> None:
    """Check that ``enqueue_story_post_commit_effects`` declares the expected parameters.

    The signature must contain both ``need_quality_pass`` and
    ``memoir_correlation_id``.
    """
    import inspect

    from app.features.story.post_commit import enqueue_story_post_commit_effects

    declared = inspect.signature(enqueue_story_post_commit_effects).parameters
    for expected_name in ("need_quality_pass", "memoir_correlation_id"):
        assert expected_name in declared
# ---------------------------------------------------------------------------
# resolve_append_target
# ---------------------------------------------------------------------------
def test_resolve_append_target_forced_new_on_overflow() -> None:
    """Check that an oversized target story is not chosen as the append target.

    ``session.get`` returns a story whose ``canonical_markdown`` is 200,000
    characters long and ``count_story_versions_sync`` is patched to return 1.
    The resolved target id must be ``None`` and the decision source must be
    ``"forced_new_due_to_append_limit"``.
    """
    from app.features.memoir.story_pipeline_sync import _resolve_append_target

    oversized_story = MagicMock()
    oversized_story.user_id = "u1"
    oversized_story.id = "story-1"
    oversized_story.canonical_markdown = "x" * 200_000

    db_session = MagicMock()
    db_session.get.return_value = oversized_story

    version_counter = patch(
        "app.features.memoir.story_pipeline_sync.count_story_versions_sync",
        return_value=1,
    )
    with version_counter:
        target_id, _existing, decision_source = _resolve_append_target(
            db_session,
            route_decision="append_story",
            route_target_story_id="story-1",
            user_id="u1",
            chapter_category="childhood",
            oral_norm="short text",
            candidate_stories=[],
            story_meta={},
            decision_source="test",
            memoir_correlation_id=None,
        )

    assert target_id is None
    assert decision_source == "forced_new_due_to_append_limit"
# ---------------------------------------------------------------------------
# _run_post_pipeline_commit helper
# ---------------------------------------------------------------------------
def test_run_post_pipeline_commit_calls_post_commit() -> None:
    """Check that ``_run_post_pipeline_commit`` calls the post-commit enqueue function once.

    ``enqueue_story_post_commit_effects`` and ``MemoirImageSettings`` are
    patched. The recorded call must carry ``need_quality_pass=True`` and
    ``memoir_correlation_id="cid-1"`` as keyword arguments.
    """
    from app.tasks.memoir_tasks import _run_post_pipeline_commit

    # Stand-in for the object the patched enqueue function returns.
    fake_post_commit_result = MagicMock(
        enqueued_story_image_count=0,
        enqueued_chapter_recompose_count=0,
        compaction_scheduled=False,
        quality_pass_scheduled=True,
        errors=[],
    )

    with (
        patch(
            "app.features.story.post_commit.enqueue_story_post_commit_effects"
        ) as enqueue_mock,
        patch(
            "app.features.memoir.memoir_images.settings.MemoirImageSettings"
        ) as image_settings_mock,
    ):
        enqueue_mock.return_value = fake_post_commit_result
        image_settings_mock.from_env.return_value = MagicMock(enabled=False)

        _run_post_pipeline_commit(
            user_id="u1",
            story_dispatch_ids={"s1"},
            recompose_chapter_ids={"c1"},
            cover_chapter_ids=set(),
            trigger_source="test",
            need_compaction=False,
            need_quality_pass=True,
            memoir_correlation_id="cid-1",
        )

    enqueue_mock.assert_called_once()
    recorded_kwargs = enqueue_mock.call_args.kwargs
    assert recorded_kwargs["need_quality_pass"] is True
    assert recorded_kwargs["memoir_correlation_id"] == "cid-1"