Simplify AI memory pipeline

This commit is contained in:
Kevin
2026-04-30 16:22:55 +08:00
parent 7617ea902c
commit 3234396254
35 changed files with 1002 additions and 579 deletions

View File

@@ -27,6 +27,19 @@ async def test_memory_ingest_passes_lineage(monkeypatch) -> None:
captured["scheduled"] = request
return "task-1"
class FakeEmbeddingScheduler:
def schedule(self, request):
captured["embedding_scheduled"] = request
return "embedding-task-1"
class FakeEmbeddingService:
def __init__(self, *_args, **_kwargs) -> None:
pass
async def embed_source(self, user_id: str, source_id: str) -> dict:
captured["embedded"] = (user_id, source_id)
return {"status": "success", "vectors_written": 1}
async def fake_create_source(session, **kwargs):
captured.update(kwargs)
return SimpleNamespace(id="src-1")
@@ -42,6 +55,10 @@ async def test_memory_ingest_passes_lineage(monkeypatch) -> None:
"app.features.memory.ingest_service.create_chunk",
fake_create_chunk,
)
monkeypatch.setattr(
"app.features.memory.ingest_service.MemoryEmbeddingService",
FakeEmbeddingService,
)
monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", False)
lineage = {
@@ -57,6 +74,7 @@ async def test_memory_ingest_passes_lineage(monkeypatch) -> None:
service = MemoryIngestService(
fake_session, # type: ignore[arg-type]
embedding_provider=None,
embedding_scheduler=FakeEmbeddingScheduler(), # type: ignore[arg-type]
enrichment_scheduler=FakeScheduler(), # type: ignore[arg-type]
)
sid = await service.ingest_transcript(

View File

@@ -1,20 +1,14 @@
"""JSON 载荷解析、证据格式化、Story 批量规划校验(纯函数)。"""
import pytest
from app.agents.chat.reply_limits import truncate_chat_segments
from app.agents.memoir.classification_agent import _normalize_llm_category
from app.agents.memoir.prompts import format_evidence_chunks_for_prompt
from app.features.memory.evidence_format import (
format_evidence_chunks_for_prompt as format_evidence_from_memory,
)
from app.agents.memoir.story_route_agent import (
StoryBatchPlan,
StoryBatchPlanUnit,
validate_story_batch_plan,
)
from app.core.json_utils import extract_json_payload
from app.features.memory.evidence_format import format_evidence_chunks_for_prompt
def test_extract_json_payload_strips_markdown_fence() -> None:
@@ -34,29 +28,19 @@ def test_normalize_llm_category_strips_quotes() -> None:
assert _normalize_llm_category("`beliefs`") == "beliefs"
def test_format_evidence_chunks_includes_timeline() -> None:
def test_format_evidence_chunks_uses_memory_formatter_without_timeline() -> None:
ev = {
"relevant_chunks": [{"content": "chunk1"}],
"relevant_facts": [
{"subject": "", "predicate": "生于", "object_json": "1950"}
],
"timeline_hints": [
{
"id": "1",
"event_year": 1977,
"event_date": None,
"title": "恢复高考",
"description": "参加了考试",
}
],
"relevant_summaries": [],
"relevant_stories": [],
}
out = format_evidence_chunks_for_prompt(ev)
assert "chunk1" in out
assert "1950" in out or "生于" in out
assert "1977" in out or "恢复高考" in out
assert format_evidence_from_memory(ev) == out
assert "恢复高考" not in out
def test_validate_story_batch_plan_ok() -> None:

View File

@@ -9,6 +9,17 @@ from app.features.memory.prompt_adapter import MemoryPromptAdapter
from app.features.memory.runtime_types import MemoryEvidenceBundle
def test_chunk_transcript_applies_configured_overlap() -> None:
    """Check that consecutive chunks share ``overlap_chars`` characters.

    The input is built from zero-padded counters so that it is not periodic.
    A repeating input such as the digits 0-9 cycled makes the tail of one
    chunk equal the head of the next even when the chunker applies no
    overlap at all (both the chunk size and the stride are multiples of the
    period), so the overlap assertions would pass vacuously.
    """
    from app.features.memory.chunker import chunk_transcript

    max_chars = 100
    overlap_chars = 20
    # 50 counters * 5 digits = 250 characters: long enough for three chunks.
    text = "".join(f"{i:05d}" for i in range(50))

    chunks = chunk_transcript(
        text,
        max_chars=max_chars,
        overlap_chars=overlap_chars,
    )

    assert len(chunks) >= 3
    assert chunks[0][-overlap_chars:] == chunks[1][:overlap_chars]
    assert chunks[1][-overlap_chars:] == chunks[2][:overlap_chars]
def test_memory_evidence_bundle_and_prompt_adapter_contract() -> None:
evidence = MemoryEvidenceBundle.from_mapping(
{
@@ -17,7 +28,6 @@ def test_memory_evidence_bundle_and_prompt_adapter_contract() -> None:
],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
}
)
@@ -52,7 +62,6 @@ async def test_memory_retrieval_service_delegates_to_retriever(
"relevant_chunks": [{"id": "c1", "content": "chunk"}],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
}
@@ -117,17 +126,18 @@ async def test_memory_ingest_service_commits_before_enrichment(
events.append(("create_chunk", kwargs["chunk_index"], kwargs["content"]))
return FakeRow(f"chunk-{kwargs['chunk_index']}")
async def fake_update_chunk_embedding(db, chunk_id, emb):
events.append(("update_embedding", chunk_id, tuple(emb)))
class FakeEmbeddingService:
def __init__(self, db, *, embedding_provider=None) -> None:
events.append(("embedding_service", embedding_provider is not None))
async def embed_source(self, user_id: str, source_id: str) -> dict:
events.append(("embed_source", user_id, source_id))
return {"status": "success", "vectors_written": 2}
monkeypatch.setattr(ingest_mod, "chunk_transcript", lambda text: ["a", "b"])
monkeypatch.setattr(ingest_mod, "create_source", fake_create_source)
monkeypatch.setattr(ingest_mod, "create_chunk", fake_create_chunk)
monkeypatch.setattr(
ingest_mod,
"update_chunk_embedding",
fake_update_chunk_embedding,
)
monkeypatch.setattr(ingest_mod, "MemoryEmbeddingService", FakeEmbeddingService)
source_id = await MemoryIngestService(
FakeDb(),
@@ -139,9 +149,150 @@ async def test_memory_ingest_service_commits_before_enrichment(
assert events.index(("commit",)) < events.index(
("schedule", "user-1", "source-1")
)
assert ("embed_texts", ("a", "b")) in events
assert ("update_embedding", "chunk-0", (1.0,)) in events
assert ("update_embedding", "chunk-1", (2.0,)) in events
assert ("embed_source", "user-1", "source-1") in events
@pytest.mark.asyncio
async def test_memory_ingest_succeeds_and_retries_when_embedding_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Ingest returns the source id even when inline embedding reports failure.

    The stubbed embedding service answers ``{"status": "failed", ...}``. The
    test then expects an embedding retry and an enrichment task to have been
    scheduled for the same user/source, and the database commit to have been
    recorded before the embedding attempt.
    """
    from app.features.memory import ingest_service as ingest_mod
    from app.features.memory.ingest_service import MemoryIngestService

    # Ordered log of every call observed by the fakes below.
    events: list[tuple] = []
    record = events.append

    @dataclass
    class FakeRow:
        id: str

    class FakeDb:
        async def flush(self) -> None:
            record(("flush",))

        async def commit(self) -> None:
            record(("commit",))

    class FakeEmbeddingService:
        def __init__(self, db, *, embedding_provider=None) -> None:
            pass

        async def embed_source(self, user_id: str, source_id: str) -> dict:
            record(("embed_source", user_id, source_id))
            # Simulate the upstream embedding call failing.
            return {"status": "failed", "error": "upstream_timeout"}

    class FakeEmbeddingScheduler:
        def schedule(self, request) -> str:
            record(("embed_retry", request.user_id, request.source_id))
            return "embed-retry-1"

    class FakeEmbedding:
        def is_available(self) -> bool:
            return True

    class FakeEnrichmentScheduler:
        def schedule(self, request) -> str:
            record(("enrich", request.user_id, request.source_id))
            return "enrich-1"

    async def fake_create_source(db, **kwargs):
        record(("create_source", kwargs["user_id"], kwargs["conversation_id"]))
        return FakeRow("source-1")

    async def fake_create_chunk(db, **kwargs):
        index = kwargs["chunk_index"]
        record(("create_chunk", index, kwargs["content"]))
        return FakeRow(f"chunk-{index}")

    # Replace the module-level collaborators of the ingest service.
    stubs = {
        "chunk_transcript": lambda text: ["a"],
        "create_source": fake_create_source,
        "create_chunk": fake_create_chunk,
        "MemoryEmbeddingService": FakeEmbeddingService,
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(ingest_mod, attr_name, stub)

    service = MemoryIngestService(
        FakeDb(),
        embedding_provider=FakeEmbedding(),
        embedding_scheduler=FakeEmbeddingScheduler(),
        enrichment_scheduler=FakeEnrichmentScheduler(),
    )
    source_id = await service.ingest_transcript("user-1", "conv-1", "hello")

    embed_call = ("embed_source", "user-1", "source-1")
    assert source_id == "source-1"
    assert ("embed_retry", "user-1", "source-1") in events
    assert ("enrich", "user-1", "source-1") in events
    assert events.index(("commit",)) < events.index(embed_call)
@pytest.mark.asyncio
async def test_exclude_chunk_stales_derived_facts(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Excluding a chunk marks its facts stale and records the count.

    The stale-marking stub reports two affected facts; the test expects that
    number and the caller's reason to appear in the ``details`` passed to the
    curation-action stub.
    """
    from app.features.memory import service as service_mod
    from app.features.memory.service import MemoryService

    # Ordered log of every call observed by the fakes below.
    events: list[tuple] = []
    record = events.append

    class FakeDb:
        async def commit(self) -> None:
            record(("commit",))

    async def fake_set_chunk_excluded(db, chunk_id, user_id, excluded):
        record(("set_excluded", chunk_id, user_id, excluded))
        return True

    async def fake_stale(db, *, user_id, chunk_id):
        record(("stale_facts", user_id, chunk_id))
        return 2

    async def fake_curation(db, **kwargs):
        record(("curation", kwargs))

    stubs = {
        "set_chunk_excluded": fake_set_chunk_excluded,
        "mark_facts_stale_for_excluded_chunk": fake_stale,
        "create_curation_action": fake_curation,
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(service_mod, attr_name, stub)

    service = MemoryService(FakeDb())
    ok = await service.exclude_chunk("user-1", "chunk-1", reason="wrong memory")

    assert ok is True
    assert ("stale_facts", "user-1", "chunk-1") in events

    # Keyword arguments of the first recorded curation call.
    curation_payloads = [entry[1] for entry in events if entry[0] == "curation"]
    curation = curation_payloads[0]
    expected_details = {
        "reason": "wrong memory",
        "staled_fact_count": 2,
    }
    assert curation["details"] == expected_details
@pytest.mark.asyncio
async def test_restore_chunk_records_reenrichment_policy(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Restoring a chunk records the ``requires_reenrichment`` fact policy.

    The test inspects the keyword arguments of the first curation-action call
    and expects its ``details`` to carry only that policy marker.
    """
    from app.features.memory import service as service_mod
    from app.features.memory.service import MemoryService

    # Keyword arguments of each curation-action call, in call order.
    captured: list[dict] = []

    class FakeDb:
        async def commit(self) -> None:
            pass

    async def fake_set_chunk_excluded(db, chunk_id, user_id, excluded):
        return True

    async def fake_curation(db, **kwargs):
        captured.append(kwargs)

    stubs = {
        "set_chunk_excluded": fake_set_chunk_excluded,
        "create_curation_action": fake_curation,
    }
    for attr_name, stub in stubs.items():
        monkeypatch.setattr(service_mod, attr_name, stub)

    service = MemoryService(FakeDb())
    ok = await service.restore_chunk("user-1", "chunk-1")

    assert ok is True
    first_call = captured[0]
    assert first_call["details"] == {"fact_restore_policy": "requires_reenrichment"}
def test_memory_single_chain_architecture_guard() -> None:
@@ -158,6 +309,10 @@ def test_memory_single_chain_architecture_guard() -> None:
"memory_fact_search_use_recent" + "_fallback",
"memory_evidence_empty_query_include" + "_rolling",
"_interview_meta" + "_store",
"timeline" + "_hints",
"parse_json" + "_payload",
"from app.agents.memoir.prompts import "
"format_evidence_chunks_for_prompt",
]
roots = [
repo_root / "api" / "app",

View File

@@ -7,7 +7,7 @@ from types import SimpleNamespace
import pytest
from app.features.memory.enrichment import enrich_memory_after_ingest_async
from app.features.memory.llm_schemas import EnrichmentPayload, parse_json_payload
from app.features.memory.llm_schemas import EnrichmentPayload
from app.features.memory.models import MemorySource
from app.features.user.models import User
@@ -19,8 +19,7 @@ def test_enrichment_payload_roundtrip() -> None:
'"object_json":{"value":"北京","approximate_era":"1990年代"},'
'"confidence":0.85,"source_chunk_id":"ch-1"}]}'
)
p = parse_json_payload(raw, EnrichmentPayload)
assert p is not None
p = EnrichmentPayload.model_validate_json(raw)
assert p.summary == "要点摘要"
assert len(p.facts) == 1
assert p.facts[0].subject == "王伟"
@@ -36,16 +35,27 @@ async def test_enrich_memory_after_ingest_async_single_llm_call(
invoke_count = {"n": 0}
async def fake_invoke(llm, prompt, max_tokens, agent):
async def fake_run(llm, numbered, narrator_label):
invoke_count["n"] += 1
assert agent == "memory.enrichment"
return (
'{"summary":"本轮要点",'
'"facts":[{"fact_type":"event","subject":"王伟","predicate":"",'
'"object_json":{"value":"上海"},"confidence":0.8,"source_chunk_id":"ch1"}]}'
assert "[chunk_id=ch1]" in numbered
assert narrator_label == "老王"
return EnrichmentPayload.model_validate(
{
"summary": "本轮要点",
"facts": [
{
"fact_type": "event",
"subject": "王伟",
"predicate": "",
"object_json": {"value": "上海"},
"confidence": 0.8,
"source_chunk_id": "ch1",
}
],
}
)
monkeypatch.setattr(mod, "ainvoke_json_object", fake_invoke)
monkeypatch.setattr(mod, "_run_enrichment_llm_async", fake_run)
summaries: list[dict] = []
facts: list[dict] = []
@@ -74,7 +84,7 @@ async def test_enrich_memory_after_ingest_async_single_llm_call(
if model is User and key == "u1":
return SimpleNamespace(nickname="老王")
if model is MemorySource and key == "src-1":
return SimpleNamespace(lineage_json=None)
return SimpleNamespace(user_id="u1", lineage_json=None)
return None
async def execute(self, _stmt):
@@ -105,10 +115,10 @@ async def test_enrich_memory_skips_when_parse_returns_none(
monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", True)
async def fake_invoke(*_args, **_kwargs):
return "{not json"
async def fake_run(*_args, **_kwargs):
return None
monkeypatch.setattr(mod, "ainvoke_json_object", fake_invoke)
monkeypatch.setattr(mod, "_run_enrichment_llm_async", fake_run)
called = {"summary": False, "fact": False}
async def capture_summary(*_args, **_kwargs):
@@ -135,7 +145,7 @@ async def test_enrich_memory_skips_when_parse_returns_none(
if model is User and key == "u":
return None
if model is MemorySource and key == "s":
return SimpleNamespace(lineage_json=None)
return SimpleNamespace(user_id="u", lineage_json=None)
return None
async def execute(self, _stmt):

View File

@@ -7,7 +7,6 @@ from app.features.memory.evidence import (
EMPTY_EVIDENCE_BUNDLE,
_facts_to_dicts,
_stories_to_dicts,
_timeline_to_dicts,
retrieve_evidence_bundle_async,
)
from app.features.memory.evidence_format import format_evidence_chunks_for_chat_prompt
@@ -19,7 +18,6 @@ def test_empty_evidence_bundle_keys() -> None:
"relevant_chunks",
"relevant_summaries",
"relevant_facts",
"timeline_hints",
"relevant_stories",
}
@@ -31,7 +29,6 @@ def test_evidence_bundle_model_accepts_dict() -> None:
def test_format_helpers_empty() -> None:
assert _facts_to_dicts([]) == []
assert _timeline_to_dicts([]) == []
assert _stories_to_dicts([]) == []
@@ -42,7 +39,6 @@ def test_format_evidence_chunks_for_chat_prompt_reframes_and_labels() -> None:
],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
}
text = format_evidence_chunks_for_chat_prompt(evidence)
@@ -73,7 +69,6 @@ def test_slice_interview_memory_retrieval_not_equal_inject_dismissive():
],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
}
s = slice_interview_memory(evidence, "哈哈,早就不会了")
@@ -92,7 +87,6 @@ def test_slice_interview_memory_minimal_inject_when_aligned():
],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
}
s = slice_interview_memory(evidence, "那次排练其实挺紧张的,灯光一打我就忘词。")
@@ -111,7 +105,6 @@ def test_slice_interview_memory_keeps_first_person_but_marks_ownership():
],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
}
s = slice_interview_memory(evidence, "那条河一到夏天就特别热闹,我现在都记得。")
@@ -129,7 +122,6 @@ def test_slice_interview_memory_suppresses_long_new_topic():
],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
}
long_msg = "我今天想随便聊聊工作里的事,项目压力很大。" * 6
@@ -153,7 +145,6 @@ async def test_retrieve_evidence_bundle_async_non_empty_merges_precomputed_chunk
"object_json": {},
}
],
"timeline_hints": [],
"relevant_summaries": [
{
"id": "s1",

View File

@@ -4,7 +4,6 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from app.agents.memoir.prompts import format_evidence_chunks_for_prompt
from app.agents.memoir.story_route_agent import StoryRouteDecision
from app.agents.state_schema import MemoirStateSchema
from app.features.asset import models as _asset_models # noqa: F401
@@ -15,6 +14,7 @@ from app.features.memoir.story_pipeline_sync import (
run_story_pipeline_for_category_batch,
)
from app.features.memory import models as _memory_models # noqa: F401
from app.features.memory.evidence_format import format_evidence_chunks_for_prompt
from app.features.payment import models as _payment_models # noqa: F401
from app.features.story import models as _story_models # noqa: F401
from app.features.user import models as _user_models # noqa: F401
@@ -47,7 +47,6 @@ def test_single_segment_decide_receives_only_combined_text_not_evidence() -> Non
}
],
"relevant_facts": [{"subject": "X", "predicate": "y", "object_json": {}}],
"timeline_hints": [],
"relevant_stories": [],
}
evidence_formatted = format_evidence_chunks_for_prompt(evidence_payload)
@@ -236,7 +235,6 @@ def test_decide_receives_only_same_stage_story_candidates() -> None:
"relevant_chunks": [],
"relevant_summaries": [],
"relevant_facts": [],
"timeline_hints": [],
"relevant_stories": [],
},
)