339 lines
11 KiB
Python
339 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from app.features.memory.prompt_adapter import MemoryPromptAdapter
|
|
from app.features.memory.runtime_types import MemoryEvidenceBundle
|
|
|
|
|
|
def test_chunk_transcript_applies_configured_overlap() -> None:
    """Consecutive chunks must share ``overlap_chars`` characters.

    The last 20 characters of each chunk have to be the first 20 characters
    of the next one. The transcript is built from 250 *distinct* characters,
    so any 20-character window occurs at exactly one position in it.

    A periodic input (for example the digits 0-9 repeated) cannot detect a
    broken overlap. With such text, any chunk stride that is a multiple of
    the period makes these assertions pass, including a stride of
    ``max_chars``, i.e. no overlap at all.
    """
    from app.features.memory.chunker import chunk_transcript

    # U+4E00.. are CJK ideographs. Every character here is unique, and none
    # is whitespace or punctuation that a chunker could treat as a boundary.
    text = "".join(chr(0x4E00 + i) for i in range(250))

    chunks = chunk_transcript(text, max_chars=100, overlap_chars=20)

    assert len(chunks) >= 3
    assert chunks[0][-20:] == chunks[1][:20]
    assert chunks[1][-20:] == chunks[2][:20]
|
|
|
|
|
|
def test_memory_evidence_bundle_and_prompt_adapter_contract() -> None:
    """A one-chunk evidence bundle is non-empty and is quoted by the adapter.

    Builds a ``MemoryEvidenceBundle`` from a mapping that holds a single
    relevant chunk. It then slices the bundle for an interview turn and
    checks three things:
    - the adapter reports that retrieval happened;
    - the prompt excerpt contains the phrase "用户曾说" ("the user once said");
    - the anchor source starts with that phrase.
    """
    payload = {
        "relevant_chunks": [
            {"id": "c1", "content": "我小时候在河边长大,夏天常去玩水。"},
        ],
        "relevant_summaries": [],
        "relevant_facts": [],
        "relevant_stories": [],
    }
    bundle = MemoryEvidenceBundle.from_mapping(payload)

    adapter = MemoryPromptAdapter()
    utterance = "那条河一到夏天就特别热闹,我现在都记得。"
    prompt_slices = adapter.slice_for_interview(bundle, utterance)

    assert bundle.has_any is True
    assert prompt_slices.had_retrieval is True
    assert "用户曾说" in prompt_slices.prompt_excerpt
    assert prompt_slices.anchor_source.startswith("用户曾说")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_retrieval_service_delegates_to_retriever(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``MemoryRetrievalService`` forwards construction and queries to ``HybridRetriever``.

    ``HybridRetriever`` is replaced by a stub. The stub logs its constructor
    arguments and its ``retrieve`` call, then returns a fixed mapping. The
    test checks that both calls were forwarded verbatim, in order, and that
    the stub's chunks come back on the returned bundle.
    """
    from app.features.memory import retrieval_service as retrieval_mod
    from app.features.memory.retrieval_service import MemoryRetrievalService

    recorded: list[dict] = []

    class StubRetriever:
        def __init__(self, db, *, embedding_provider=None) -> None:
            recorded.append({"db": db, "embedding_provider": embedding_provider})

        async def retrieve(self, *, user_id: str, query: str, top_k: int) -> dict:
            recorded.append({"user_id": user_id, "query": query, "top_k": top_k})
            return {
                "relevant_chunks": [{"id": "c1", "content": "chunk"}],
                "relevant_summaries": [],
                "relevant_facts": [],
                "relevant_stories": [],
            }

    class AvailableEmbedding:
        def is_available(self) -> bool:
            return True

    session = object()
    provider = AvailableEmbedding()
    # Patch before the service is constructed so it can only see the stub.
    monkeypatch.setattr(retrieval_mod, "HybridRetriever", StubRetriever)

    service = MemoryRetrievalService(session, embedding_provider=provider)
    bundle = await service.retrieve("user-1", "hello", top_k=3)

    expected_calls = [
        {"db": session, "embedding_provider": provider},
        {"user_id": "user-1", "query": "hello", "top_k": 3},
    ]
    assert recorded == expected_calls
    assert bundle.relevant_chunks == [{"id": "c1", "content": "chunk"}]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_ingest_service_commits_before_enrichment(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Ingest commits before it schedules enrichment, and embeds the source.

    Every collaborator of ``MemoryIngestService`` is stubbed. Each stub
    appends a tagged tuple to a shared log, so the test can check ordering:
    the first commit comes before the enrichment ``schedule`` call. It also
    checks that ``embed_source`` ran for the new source.
    """
    from app.features.memory import ingest_service as ingest_mod
    from app.features.memory.ingest_service import MemoryIngestService

    log: list[tuple] = []

    @dataclass
    class Row:
        id: str

    class RecordingDb:
        async def flush(self) -> None:
            log.append(("flush",))

        async def commit(self) -> None:
            log.append(("commit",))

    class StubEmbedding:
        async def embed_texts(self, texts: list[str]) -> list[list[float]]:
            log.append(("embed_texts", tuple(texts)))
            return [[1.0], [2.0]]

        def is_available(self) -> bool:
            return True

    class RecordingScheduler:
        def schedule(self, request) -> str:
            log.append(("schedule", request.user_id, request.source_id))
            return "enrich-1"

    async def stub_create_source(db, **kwargs):
        log.append(("create_source", kwargs["user_id"], kwargs["conversation_id"]))
        return Row("source-1")

    async def stub_create_chunk(db, **kwargs):
        index = kwargs["chunk_index"]
        log.append(("create_chunk", index, kwargs["content"]))
        return Row(f"chunk-{index}")

    class StubEmbeddingService:
        def __init__(self, db, *, embedding_provider=None) -> None:
            log.append(("embedding_service", embedding_provider is not None))

        async def embed_source(self, user_id: str, source_id: str) -> dict:
            log.append(("embed_source", user_id, source_id))
            return {"status": "success", "vectors_written": 2}

    # Patched in the same order as before; dicts preserve insertion order.
    patches = {
        "chunk_transcript": lambda text: ["a", "b"],
        "create_source": stub_create_source,
        "create_chunk": stub_create_chunk,
        "MemoryEmbeddingService": StubEmbeddingService,
    }
    for attr, replacement in patches.items():
        monkeypatch.setattr(ingest_mod, attr, replacement)

    service = MemoryIngestService(
        RecordingDb(),
        embedding_provider=StubEmbedding(),
        enrichment_scheduler=RecordingScheduler(),
    )
    source_id = await service.ingest_transcript("user-1", "conv-1", "hello")

    assert source_id == "source-1"
    first_commit = log.index(("commit",))
    enrichment_scheduled = log.index(("schedule", "user-1", "source-1"))
    assert first_commit < enrichment_scheduled
    assert ("embed_source", "user-1", "source-1") in log
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_memory_ingest_succeeds_and_retries_when_embedding_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A failed embedding pass does not fail ingest; it is handed to retry.

    The stubbed ``embed_source`` reports ``{"status": "failed"}``. Ingest
    must still do all of the following:
    - return the new source id;
    - pass the source to the embedding retry scheduler;
    - pass the source to the enrichment scheduler;
    - make its first commit before the embedding attempt.
    """
    from app.features.memory import ingest_service as ingest_mod
    from app.features.memory.ingest_service import MemoryIngestService

    log: list[tuple] = []

    @dataclass
    class Row:
        id: str

    class RecordingDb:
        async def flush(self) -> None:
            log.append(("flush",))

        async def commit(self) -> None:
            log.append(("commit",))

    class FailingEmbeddingService:
        def __init__(self, db, *, embedding_provider=None) -> None:
            pass

        async def embed_source(self, user_id: str, source_id: str) -> dict:
            log.append(("embed_source", user_id, source_id))
            return {"status": "failed", "error": "upstream_timeout"}

    class RecordingRetryScheduler:
        def schedule(self, request) -> str:
            log.append(("embed_retry", request.user_id, request.source_id))
            return "embed-retry-1"

    class StubEmbedding:
        def is_available(self) -> bool:
            return True

    class RecordingEnrichmentScheduler:
        def schedule(self, request) -> str:
            log.append(("enrich", request.user_id, request.source_id))
            return "enrich-1"

    async def stub_create_source(db, **kwargs):
        log.append(("create_source", kwargs["user_id"], kwargs["conversation_id"]))
        return Row("source-1")

    async def stub_create_chunk(db, **kwargs):
        index = kwargs["chunk_index"]
        log.append(("create_chunk", index, kwargs["content"]))
        return Row(f"chunk-{index}")

    for attr, replacement in (
        ("chunk_transcript", lambda text: ["a"]),
        ("create_source", stub_create_source),
        ("create_chunk", stub_create_chunk),
        ("MemoryEmbeddingService", FailingEmbeddingService),
    ):
        monkeypatch.setattr(ingest_mod, attr, replacement)

    service = MemoryIngestService(
        RecordingDb(),
        embedding_provider=StubEmbedding(),
        embedding_scheduler=RecordingRetryScheduler(),
        enrichment_scheduler=RecordingEnrichmentScheduler(),
    )
    source_id = await service.ingest_transcript("user-1", "conv-1", "hello")

    assert source_id == "source-1"
    assert ("embed_retry", "user-1", "source-1") in log
    assert ("enrich", "user-1", "source-1") in log
    first_commit = log.index(("commit",))
    embed_attempt = log.index(("embed_source", "user-1", "source-1"))
    assert first_commit < embed_attempt
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_exclude_chunk_stales_derived_facts(monkeypatch: pytest.MonkeyPatch) -> None:
    """Excluding a chunk marks its derived facts stale and logs the count.

    The repository helpers are stubbed. The test checks three things:
    - ``exclude_chunk`` reports success;
    - it asks for facts tied to the chunk to be marked stale;
    - the first curation action's ``details`` carry both the caller's
      reason and the number of facts the stale helper reported.
    """
    from app.features.memory import service as service_mod
    from app.features.memory.service import MemoryService

    trace: list[tuple] = []

    class RecordingDb:
        async def commit(self) -> None:
            trace.append(("commit",))

    async def stub_set_chunk_excluded(db, chunk_id, user_id, excluded):
        trace.append(("set_excluded", chunk_id, user_id, excluded))
        return True

    async def stub_mark_stale(db, *, user_id, chunk_id):
        trace.append(("stale_facts", user_id, chunk_id))
        return 2

    async def stub_create_curation(db, **kwargs):
        trace.append(("curation", kwargs))

    monkeypatch.setattr(service_mod, "set_chunk_excluded", stub_set_chunk_excluded)
    monkeypatch.setattr(
        service_mod, "mark_facts_stale_for_excluded_chunk", stub_mark_stale
    )
    monkeypatch.setattr(service_mod, "create_curation_action", stub_create_curation)

    service = MemoryService(RecordingDb())
    ok = await service.exclude_chunk("user-1", "chunk-1", reason="wrong memory")

    assert ok is True
    assert ("stale_facts", "user-1", "chunk-1") in trace
    curation_calls = [entry[1] for entry in trace if entry[0] == "curation"]
    assert curation_calls[0]["details"] == {
        "reason": "wrong memory",
        "staled_fact_count": 2,
    }
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_restore_chunk_records_reenrichment_policy(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Restoring a chunk records a "requires_reenrichment" fact policy.

    With the exclusion toggle and curation writer stubbed, ``restore_chunk``
    must report success. The first curation action's ``details`` must be
    exactly ``{"fact_restore_policy": "requires_reenrichment"}``.
    """
    from app.features.memory import service as service_mod
    from app.features.memory.service import MemoryService

    curation_kwargs: list[dict] = []

    class NoopDb:
        async def commit(self) -> None:
            return None

    async def always_updated(db, chunk_id, user_id, excluded):
        return True

    async def capture_curation(db, **kwargs):
        curation_kwargs.append(kwargs)

    monkeypatch.setattr(service_mod, "set_chunk_excluded", always_updated)
    monkeypatch.setattr(service_mod, "create_curation_action", capture_curation)

    restored = await MemoryService(NoopDb()).restore_chunk("user-1", "chunk-1")

    assert restored is True
    first_action = curation_kwargs[0]
    assert first_action["details"] == {"fact_restore_policy": "requires_reenrichment"}
|
|
|
|
|
|
def test_memory_single_chain_architecture_guard() -> None:
    """Fail if any removed memory compatibility path reappears in the repo.

    What is scanned:
    - every ``.py``, ``.md`` and ``.txt`` file under ``api/app``,
      ``api/tests`` and ``api/docs``;
    - the ``api/.env*`` files.

    Each file is checked for any of the banned snippets. Every offending
    (file, snippet) pair is collected, so a failure lists all of them at once.
    """
    repo_root = Path(__file__).resolve().parents[2]
    api_dir = repo_root / "api"

    # Each needle is assembled from fragments so that its full text never
    # appears verbatim in this file, which is itself under api/tests.
    banned = [
        "retrieve_evidence" + "_sync",
        "retrieve_evidence_bundle" + "_sync",
        "ingest_transcript" + "_sync",
        "ingest_transcripts_batch" + "_sync",
        "lineage" + "_tier=" + '"fallback"',
        "lineage" + "_tier=" + "'fallback'",
        "evidence_bundle" + "_json",
        "memory_fact_search_use_recent" + "_fallback",
        "memory_evidence_empty_query_include" + "_rolling",
        "_interview_meta" + "_store",
        "timeline" + "_hints",
        "parse_json" + "_payload",
        "from app.agents.memoir.prompts import "
        "format_evidence_chunks_for_prompt",
    ]
    scanned_suffixes = {".py", ".md", ".txt"}

    candidates: list[Path] = []
    for sub in ("app", "tests", "docs"):
        candidates.extend(
            entry
            for entry in (api_dir / sub).rglob("*")
            if entry.is_file() and entry.suffix in scanned_suffixes
        )
    candidates.extend(entry for entry in api_dir.glob(".env*") if entry.is_file())

    hits: list[str] = []
    for path in candidates:
        content = path.read_text(encoding="utf-8")
        hits.extend(
            f"{path.relative_to(repo_root)}: {needle}"
            for needle in banned
            if needle in content
        )

    assert hits == []
|