2026-03-20 10:30:07 +08:00
|
|
|
"""Hybrid retriever — metadata filter + FTS + vector retrieval + score fusion."""
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
2026-03-27 16:01:28 +08:00
|
|
|
from app.features.memory.evidence import retrieve_evidence_bundle_async
|
|
|
|
|
from app.features.memory.repo import search_chunks_fts, search_chunks_vector
|
2026-03-20 10:30:07 +08:00
|
|
|
from app.ports.embedding import EmbeddingProvider
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _rrf_merge(
|
|
|
|
|
fts_items: list[dict], vector_items: list[dict], k: int = 60
|
|
|
|
|
) -> list[dict]:
|
|
|
|
|
"""Reciprocal Rank Fusion. Merge FTS and vector results by id."""
|
|
|
|
|
scores: dict[str, float] = {}
|
|
|
|
|
for rank, item in enumerate(fts_items):
|
|
|
|
|
cid = item["id"]
|
|
|
|
|
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
|
|
|
|
|
for rank, item in enumerate(vector_items):
|
|
|
|
|
cid = item["id"]
|
|
|
|
|
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
|
|
|
|
|
|
|
|
|
|
all_items = {x["id"]: x for x in fts_items + vector_items}
|
|
|
|
|
sorted_ids = sorted(scores.keys(), key=lambda i: scores[i], reverse=True)
|
|
|
|
|
return [all_items[i] for i in sorted_ids]
|
2026-03-18 17:18:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridRetriever:
    """Combine FTS, vector, and metadata filter into evidence bundle.

    Chunk candidates come from full-text search and, when an embedding
    provider is configured, from vector similarity search. The two ranked
    lists are fused with Reciprocal Rank Fusion (``_rrf_merge``), trimmed to
    ``top_k``, and handed to ``retrieve_evidence_bundle_async`` to assemble
    the final evidence bundle.
    """

    def __init__(
        self,
        db: AsyncSession,
        *,
        embedding_provider: EmbeddingProvider | None = None,
    ):
        """Store the DB session and the optional embedding provider.

        Args:
            db: Async SQLAlchemy session passed to every search call.
            embedding_provider: Used to embed the query for vector search;
                when ``None``, chunk retrieval is FTS-only.
        """
        self._db = db
        self._embedding = embedding_provider

    async def retrieve(self, user_id: str, query: str, *, top_k: int = 10) -> dict:
        """
        Return evidence bundle:
        {relevant_chunks, relevant_summaries, relevant_facts, timeline_hints, relevant_stories}

        The bundle itself is built by ``retrieve_evidence_bundle_async``;
        this method supplies the fused chunk hits via ``merged_chunk_dicts``.

        Args:
            user_id: Owner whose chunks are searched.
            query: Free-text query. Surrounding whitespace is ignored for
                chunk search; a blank query skips chunk search entirely.
            top_k: Maximum number of fused chunks passed to the bundle
                builder (also forwarded to it as ``top_k``).
        """
        q = query.strip()
        if not q:
            # Nothing to search for: build the bundle with no chunk hits.
            return await retrieve_evidence_bundle_async(
                self._db,
                user_id,
                query,
                top_k=top_k,
                merged_chunk_dicts=[],
            )

        # Over-fetch from each leg so fusion has extra candidates to reorder
        # before the result is cut back down to top_k.
        candidate_limit = top_k * 2

        # Both retrieval legs receive the same normalized query text.
        fts_chunks = await search_chunks_fts(
            self._db, user_id=user_id, query=q, limit=candidate_limit
        )
        vector_chunks = await self._search_vector(user_id, q, candidate_limit)

        merged = _rrf_merge(fts_chunks, vector_chunks)[:top_k]
        # Keep only the fields the bundle builder is given; chunk_index
        # defaults to 0 when a hit does not carry one.
        merged_chunk_dicts = [
            {
                "id": c["id"],
                "content": c["content"],
                "chunk_index": c.get("chunk_index", 0),
            }
            for c in merged
        ]

        return await retrieve_evidence_bundle_async(
            self._db,
            user_id,
            query,
            top_k=top_k,
            merged_chunk_dicts=merged_chunk_dicts,
        )

    async def _search_vector(self, user_id: str, q: str, limit: int) -> list[dict]:
        """Return vector-search hits for ``q``, or ``[]`` when unavailable.

        Vector search is skipped when no embedding provider is configured or
        when the provider returns an empty/falsy embedding.
        """
        if self._embedding is None:
            return []
        q_emb = await self._embedding.embed_text(q)
        if not q_emb:
            return []
        return await search_chunks_vector(
            self._db, user_id=user_id, query_embedding=q_emb, limit=limit
        )
|