Files
life-echo/api/app/features/memory/retriever.py
2026-04-30 16:22:55 +08:00

69 lines
2.1 KiB
Python

"""Hybrid retriever — 向量检索 + 元数据证据包。"""
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.features.memory.evidence import retrieve_evidence_bundle_async
from app.features.memory.repo import search_chunks_vector
from app.ports.embedding import EmbeddingProvider
# Module logger. The warning calls in this module pass ``{}``-style
# placeholders with positional args (brace formatting, presumably
# loguru-style — confirm against app.core.logging.get_logger).
logger = get_logger(__name__)
class HybridRetriever:
    """Assemble a memory evidence bundle for a user's query.

    Combines an optional vector search over stored chunks (via
    ``search_chunks_vector``) with the evidence bundle built by
    ``retrieve_evidence_bundle_async``, which is documented as adding
    facts, summaries and stories. The vector step is skipped, with a logged
    warning, when no embedding provider is configured or when the provider
    returns an empty embedding.
    """

    def __init__(
        self,
        db: AsyncSession,
        *,
        embedding_provider: EmbeddingProvider | None = None,
    ):
        """Bind the retriever to a DB session and an optional embedder.

        Args:
            db: Async SQLAlchemy session, used both for the vector search
                and for building the evidence bundle.
            embedding_provider: Turns query text into an embedding. When it
                is ``None`` (or otherwise falsy), retrieval returns the
                evidence bundle without vector-matched chunks.
        """
        self._embedding = embedding_provider
        self._db = db

    async def retrieve(self, user_id: str, query: str, *, top_k: int = 10) -> dict:
        """Return the evidence bundle for ``query``.

        The bundle is documented as
        ``{relevant_chunks, relevant_summaries, relevant_facts, relevant_stories}``.
        Its exact shape is determined by ``retrieve_evidence_bundle_async``.

        Args:
            user_id: Owner whose memory is searched.
            query: Free-text query. An empty or whitespace-only query skips
                the vector search. Surrounding whitespace is stripped only
                for embedding; the evidence builder always receives
                ``query`` unchanged.
            top_k: Used both as the vector-search ``limit`` and as the
                evidence builder's ``top_k``.
        """
        text = query.strip()
        # Blank queries go straight to the evidence builder with no chunks.
        chunk_dicts: list[dict] = (
            await self._vector_chunk_dicts(user_id, text, top_k) if text else []
        )
        return await retrieve_evidence_bundle_async(
            self._db,
            user_id,
            query,
            top_k=top_k,
            merged_chunk_dicts=chunk_dicts,
        )

    async def _vector_chunk_dicts(
        self, user_id: str, text: str, top_k: int
    ) -> list[dict]:
        """Embed ``text`` and return vector-matched chunks as plain dicts.

        ``top_k`` is passed as the search ``limit``. Returns an empty list,
        after logging a warning, when no embedding provider is set or the
        provider yields an empty embedding. Each result dict carries
        ``id``, ``content`` and ``chunk_index`` (0 when the row has none).
        """
        provider = self._embedding
        if not provider:
            logger.warning("HybridRetriever no_embedding_provider user_id={}", user_id)
            return []

        embedding = await provider.embed_text(text)
        if not embedding:
            logger.warning(
                "HybridRetriever empty_query_embedding user_id={}", user_id
            )
            return []

        rows = await search_chunks_vector(self._db, user_id, embedding, limit=top_k)
        # Project each row down to id/content/chunk_index, keeping search order.
        return [
            {
                "id": row["id"],
                "content": row["content"],
                "chunk_index": row.get("chunk_index", 0),
            }
            for row in rows
        ]