Files
life-echo/api/tests/test_memory_enrichment_baseline.py
2026-04-30 16:22:55 +08:00

161 lines
4.8 KiB
Python

"""Baseline memory enrichment: single LLM call → session summary + facts."""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.features.memory.enrichment import enrich_memory_after_ingest_async
from app.features.memory.llm_schemas import EnrichmentPayload
from app.features.memory.models import MemorySource
from app.features.user.models import User
def test_enrichment_payload_roundtrip() -> None:
    """Parse a raw LLM JSON reply into ``EnrichmentPayload`` and back again.

    The fixture has the shape the enrichment step consumes: a ``summary``
    string plus a list of ``facts``. One fact carries an empty ``predicate``,
    which the schema must accept. After parsing, the payload is serialised
    and re-parsed to check that the schema survives a full round trip.
    """
    raw = (
        '{"summary":"要点摘要",'
        '"facts":[{"fact_type":"event","subject":"王伟","predicate":"",'
        '"object_json":{"value":"北京","approximate_era":"1990年代"},'
        '"confidence":0.85,"source_chunk_id":"ch-1"}]}'
    )
    payload = EnrichmentPayload.model_validate_json(raw)
    assert payload.summary == "要点摘要"
    assert len(payload.facts) == 1
    assert payload.facts[0].subject == "王伟"

    # Without this half the test only covered parsing, despite its name.
    # Compare dumps rather than models so the check does not depend on how
    # model equality treats bookkeeping such as which fields were explicitly set.
    reparsed = EnrichmentPayload.model_validate_json(payload.model_dump_json())
    assert reparsed.model_dump() == payload.model_dump()
@pytest.mark.asyncio
async def test_enrich_memory_after_ingest_async_single_llm_call(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """One enrichment pass makes exactly one LLM call and persists its output.

    The LLM runner and both persistence helpers are replaced with recorders.
    The DB session is faked so that:

    * user ``"u1"`` has nickname ``"老王"``;
    * source ``"src-1"`` belongs to ``"u1"``;
    * every ``execute`` yields the single chunk ``ch1``.

    The stubbed payload holds one summary and one fact.

    Expectations on the runner's arguments are checked after the call, not
    inside the fake. If the enrichment code wraps the LLM call in a broad
    ``except`` (not visible from here), an ``AssertionError`` raised inside
    the fake could be swallowed and surface only as a confusing later failure.
    """
    from app.features.memory import enrichment as mod

    monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", True)

    llm_sentinel = object()
    runner_calls: list[dict] = []

    async def fake_run(llm, numbered, narrator_label):
        # Record only; assertions run after the code under test has returned.
        runner_calls.append(
            {"llm": llm, "numbered": numbered, "narrator_label": narrator_label}
        )
        return EnrichmentPayload.model_validate(
            {
                "summary": "本轮要点",
                "facts": [
                    {
                        "fact_type": "event",
                        "subject": "王伟",
                        "predicate": "",
                        "object_json": {"value": "上海"},
                        "confidence": 0.8,
                        "source_chunk_id": "ch1",
                    }
                ],
            }
        )

    monkeypatch.setattr(mod, "_run_enrichment_llm_async", fake_run)

    summaries: list[dict] = []
    facts: list[dict] = []

    async def capture_summary(session, **kwargs):
        summaries.append(kwargs)

    async def capture_fact(session, **kwargs):
        facts.append(kwargs)

    monkeypatch.setattr(mod, "create_memory_summary", capture_summary)
    monkeypatch.setattr(mod, "create_memory_fact", capture_fact)

    class FakeResult:
        """Chainable stand-in for ``result.unique().scalars().all()``."""

        def unique(self):
            return self

        def scalars(self):
            return self

        def all(self):
            return [SimpleNamespace(id="ch1", content="王伟住在上海。")]

    class FakeSession:
        """Minimal async session: keyed ``get`` lookups plus a fixed ``execute``."""

        async def get(self, model, key):
            if model is User and key == "u1":
                return SimpleNamespace(nickname="老王")
            if model is MemorySource and key == "src-1":
                return SimpleNamespace(user_id="u1", lineage_json=None)
            return None

        async def execute(self, _stmt):
            return FakeResult()

    await enrich_memory_after_ingest_async(
        FakeSession(),  # type: ignore[arg-type]
        "u1",
        "src-1",
        llm=llm_sentinel,
    )

    # Exactly one LLM round trip, fed the injected llm, numbered chunk text
    # and narrator label.
    assert len(runner_calls) == 1
    call = runner_calls[0]
    assert call["llm"] is llm_sentinel
    assert "[chunk_id=ch1]" in call["numbered"]
    assert call["narrator_label"] == "老王"

    # The payload's summary is stored once, as a session summary over ch1.
    assert len(summaries) == 1
    assert summaries[0]["summary_type"] == "session"
    assert summaries[0]["content"] == "本轮要点"
    assert summaries[0]["source_chunk_ids"] == ["ch1"]

    # The single fact is stored with its empty predicate intact and status "confirmed".
    assert len(facts) == 1
    assert facts[0]["predicate"] == ""
    assert facts[0]["status"] == "confirmed"
@pytest.mark.asyncio
async def test_enrich_memory_skips_when_parse_returns_none(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Nothing is persisted when the LLM runner yields no payload.

    The runner is stubbed to return ``None``, standing in for an unparsable
    LLM reply. The rest of the fixture matches the single-call happy path:
    a known user, a source owned by that user, and one non-trivial chunk.
    That leaves the ``None`` payload as the only reason to skip persistence.

    The runner's call count is asserted so the test cannot pass vacuously
    when the code returns before the LLM is ever consulted.
    """
    from app.features.memory import enrichment as mod

    monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", True)

    runner_calls = {"n": 0}

    async def fake_run(*_args, **_kwargs):
        runner_calls["n"] += 1
        return None

    monkeypatch.setattr(mod, "_run_enrichment_llm_async", fake_run)

    called = {"summary": False, "fact": False}

    async def capture_summary(*_args, **_kwargs):
        called.update(summary=True)

    async def capture_fact(*_args, **_kwargs):
        called.update(fact=True)

    monkeypatch.setattr(mod, "create_memory_summary", capture_summary)
    monkeypatch.setattr(mod, "create_memory_fact", capture_fact)

    class FakeResult:
        """Chainable stand-in for ``result.unique().scalars().all()``."""

        def unique(self):
            return self

        def scalars(self):
            return self

        def all(self):
            # Same content as the happy-path test, so input size cannot be
            # the reason enrichment is skipped.
            return [SimpleNamespace(id="c1", content="王伟住在上海。")]

    class FakeSession:
        """Minimal async session: keyed ``get`` lookups plus a fixed ``execute``."""

        async def get(self, model, key):
            if model is User and key == "u":
                # Present user: a missing user must not mask the parse-None path.
                return SimpleNamespace(nickname="老王")
            if model is MemorySource and key == "s":
                return SimpleNamespace(user_id="u", lineage_json=None)
            return None

        async def execute(self, _stmt):
            return FakeResult()

    await enrich_memory_after_ingest_async(
        FakeSession(),  # type: ignore[arg-type]
        "u",
        "s",
        llm=object(),
    )

    # The LLM was actually consulted, and its empty result led to no writes.
    assert runner_calls["n"] == 1
    assert called == {"summary": False, "fact": False}