69 lines
2.1 KiB
Python
69 lines
2.1 KiB
Python
"""Hybrid retriever — 向量检索 + 元数据证据包。"""
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.features.memory.evidence import retrieve_evidence_bundle_async
|
|
from app.features.memory.repo import search_chunks_vector
|
|
from app.ports.embedding import EmbeddingProvider
|
|
|
|
# Module-level logger obtained from the project's get_logger helper,
# keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class HybridRetriever:
    """Vector chunk retrieval combined with facts/summaries/stories."""

    def __init__(
        self,
        db: AsyncSession,
        *,
        embedding_provider: EmbeddingProvider | None = None,
    ):
        """Store the DB session and the (optional) embedding provider."""
        self._embedding = embedding_provider
        self._db = db

    async def _vector_chunks(self, user_id: str, text: str, limit: int) -> list[dict]:
        """Embed *text* and return the matching chunks as plain dicts.

        Returns an empty list (after logging a warning) when no embedding
        provider is configured or when the provider yields a falsy embedding.
        """
        provider = self._embedding
        if not provider:
            logger.warning("HybridRetriever no_embedding_provider user_id={}", user_id)
            return []

        embedding = await provider.embed_text(text)
        if not embedding:
            logger.warning(
                "HybridRetriever empty_query_embedding user_id={}", user_id
            )
            return []

        rows = await search_chunks_vector(self._db, user_id, embedding, limit=limit)
        # Keep only the fields forwarded to the evidence bundle; a missing
        # chunk_index falls back to 0.
        return [
            {
                "id": row["id"],
                "content": row["content"],
                "chunk_index": row.get("chunk_index", 0),
            }
            for row in rows
        ]

    async def retrieve(self, user_id: str, query: str, *, top_k: int = 10) -> dict:
        """
        Return evidence bundle:
        {relevant_chunks, relevant_summaries, relevant_facts, relevant_stories}

        A whitespace-only query skips the vector search entirely (no embedding
        call, no warning) and builds the bundle with an empty chunk list.
        """
        stripped = query.strip()
        chunk_dicts: list[dict] = []
        if stripped:
            # The stripped text is what gets embedded ...
            chunk_dicts = await self._vector_chunks(user_id, stripped, top_k)

        # ... while the bundle builder receives the caller's query unchanged.
        return await retrieve_evidence_bundle_async(
            self._db,
            user_id,
            query,
            top_k=top_k,
            merged_chunk_dicts=chunk_dicts,
        )
|