Files
life-echo/api/app/features/memory/retriever.py
Kevin a3f61fcc0f feat(api+app): 对话阶段化、回忆录流水线与客户端会话体验
- DB: segments 用户输入文本(Alembic 0002)
- Chat: 阶段检测/阶段提示/回复限制,编排与访谈/画像 prompts 调整
- Memoir: 忠实度检查 agent,叙事与分类等链路更新
- Core: agent 日志、Alembic 启动、LangChain/日志/配置等
- Story: time_hints;Memory 检索与相关测试
- Expo: 助手头像、会话页与消息拆分、实时会话与文案/i18n
- Docs/scripts/tests: 迁移脚本、LLM JSON/记忆检索文档、新增单测
2026-03-26 12:13:36 +08:00

106 lines
3.4 KiB
Python

"""Hybrid retriever — metadata filter + FTS + vector retrieval + score fusion."""
from sqlalchemy.ext.asyncio import AsyncSession
from app.features.memory.repo import (
get_facts_for_user,
get_timeline_events_for_user,
search_chunks_fts,
search_chunks_vector,
)
from app.ports.embedding import EmbeddingProvider
def _rrf_merge(
fts_items: list[dict], vector_items: list[dict], k: int = 60
) -> list[dict]:
"""Reciprocal Rank Fusion. Merge FTS and vector results by id."""
scores: dict[str, float] = {}
for rank, item in enumerate(fts_items):
cid = item["id"]
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
for rank, item in enumerate(vector_items):
cid = item["id"]
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
all_items = {x["id"]: x for x in fts_items + vector_items}
sorted_ids = sorted(scores.keys(), key=lambda i: scores[i], reverse=True)
return [all_items[i] for i in sorted_ids]
class HybridRetriever:
    """Combine FTS, vector search, and per-user metadata into an evidence bundle.

    The retriever is bound to one database session and, optionally, an
    embedding provider. Without an embedding provider, or for a blank query,
    chunk retrieval falls back to full-text search only.
    """

    def __init__(
        self,
        db: AsyncSession,
        *,
        embedding_provider: EmbeddingProvider | None = None,
    ):
        self._db = db
        self._embedding = embedding_provider

    async def retrieve(self, user_id: str, query: str, *, top_k: int = 10) -> dict:
        """Collect evidence relevant to ``query`` for ``user_id``.

        Args:
            user_id: User whose memory is searched.
            query: Free-text query. It is passed to full-text search as given;
                vector search uses the whitespace-stripped query and is skipped
                when that is empty.
            top_k: Maximum number of chunks, facts, and timeline events to return.

        Returns:
            A dict with keys ``relevant_chunks``, ``relevant_summaries``,
            ``relevant_facts``, ``timeline_hints``, and ``relevant_stories``.
            ``relevant_summaries`` and ``relevant_stories`` are always empty
            placeholder lists in this implementation; narrative prompts should
            rely only on the populated fields (see
            ``format_evidence_chunks_for_prompt``). Facts and timeline hints
            are fetched per user and do not depend on ``query``.
        """
        # Awaited one after another on purpose: a single AsyncSession must not
        # be used by concurrent tasks.
        relevant_chunks = await self._retrieve_chunks(user_id, query, top_k)
        relevant_facts = await self._retrieve_facts(user_id, top_k)
        timeline_hints = await self._retrieve_timeline_hints(user_id, top_k)
        return {
            "relevant_chunks": relevant_chunks,
            "relevant_summaries": [],
            "relevant_facts": relevant_facts,
            "timeline_hints": timeline_hints,
            "relevant_stories": [],
        }

    async def _retrieve_chunks(
        self, user_id: str, query: str, top_k: int
    ) -> list[dict]:
        """Fuse FTS and (optional) vector chunk hits with RRF; keep the best ``top_k``."""
        # Fetch 2 * top_k candidates per source; the fused list is cut back to top_k.
        candidate_limit = top_k * 2
        fts_chunks = await search_chunks_fts(
            self._db, user_id=user_id, query=query, limit=candidate_limit
        )
        vector_chunks: list[dict] = []
        stripped_query = query.strip()
        if self._embedding and stripped_query:
            q_emb = await self._embedding.embed_text(stripped_query)
            # A falsy embedding means no vector search; FTS results are used alone.
            if q_emb:
                vector_chunks = await search_chunks_vector(
                    self._db,
                    user_id=user_id,
                    query_embedding=q_emb,
                    limit=candidate_limit,
                )
        merged = _rrf_merge(fts_chunks, vector_chunks)[:top_k]
        return [
            {
                "id": c["id"],
                "content": c["content"],
                "chunk_index": c.get("chunk_index", 0),
            }
            for c in merged
        ]

    async def _retrieve_facts(self, user_id: str, top_k: int) -> list[dict]:
        """Return up to ``top_k`` of the user's stored facts as plain dicts."""
        facts = await get_facts_for_user(self._db, user_id=user_id, limit=top_k)
        return [
            {
                "id": f.id,
                "fact_type": f.fact_type,
                "subject": f.subject,
                "predicate": f.predicate,
                "object_json": f.object_json,
            }
            for f in facts
        ]

    async def _retrieve_timeline_hints(self, user_id: str, top_k: int) -> list[dict]:
        """Return up to ``top_k`` of the user's timeline events as plain dicts."""
        events = await get_timeline_events_for_user(
            self._db, user_id=user_id, limit=top_k
        )
        return [
            {
                "id": e.id,
                "event_year": e.event_year,
                "event_date": e.event_date,
                "title": e.title,
                "description": e.description,
            }
            for e in events
        ]