2026-03-20 10:30:07 +08:00
|
|
|
"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access."""
|
|
|
|
|
|
|
|
|
|
import uuid
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import select, text
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from app.features.memory.models import (
|
|
|
|
|
MemoryChunk,
|
|
|
|
|
MemoryFact,
|
|
|
|
|
MemorySource,
|
|
|
|
|
TimelineEvent,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _new_id() -> str:
    """Return a freshly generated random (version 4) UUID as a string."""
    fresh = uuid.uuid4()
    return str(fresh)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_source_sync(
    session: Session,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Build a MemorySource and stage it on a synchronous session.

    A new ID is generated with ``_new_id()``. When ``captured_at`` is not
    given, the current UTC time is used instead. The object is only added to
    ``session``; nothing is committed here, so the caller must commit.

    Returns:
        The newly staged MemorySource.
    """
    # Fall back to "now" (timezone-aware, UTC) when no capture time is given.
    when = captured_at or datetime.now(timezone.utc)
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "source_type": source_type,
        "raw_text": raw_text,
        "conversation_id": conversation_id,
        "captured_at": when,
    }
    record = MemorySource(**fields)
    session.add(record)
    return record
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def create_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Build a MemorySource and stage it on an async session.

    A new ID is generated with ``_new_id()``. When ``captured_at`` is not
    given, the current UTC time is used instead. The object is only added to
    ``db``; nothing is committed here, so the caller must commit.

    Returns:
        The newly staged MemorySource.
    """
    # Fall back to "now" (timezone-aware, UTC) when no capture time is given.
    when = captured_at or datetime.now(timezone.utc)
    record = MemorySource(
        **{
            "id": _new_id(),
            "user_id": user_id,
            "source_type": source_type,
            "raw_text": raw_text,
            "conversation_id": conversation_id,
            "captured_at": when,
        }
    )
    db.add(record)
    return record
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_chunk_sync(
    session: Session,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Build a MemoryChunk and stage it on a synchronous session.

    A new ID is generated with ``_new_id()``; the remaining fields are taken
    verbatim from the keyword arguments. The object is only added to
    ``session``; nothing is committed here, so the caller must commit.

    Returns:
        The newly staged MemoryChunk.
    """
    fields = {
        "id": _new_id(),
        "source_id": source_id,
        "user_id": user_id,
        "content": content,
        "chunk_index": chunk_index,
    }
    piece = MemoryChunk(**fields)
    session.add(piece)
    return piece
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def create_chunk(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Build a MemoryChunk and stage it on an async session.

    A new ID is generated with ``_new_id()``; the remaining fields are taken
    verbatim from the keyword arguments. The object is only added to ``db``;
    nothing is committed here, so the caller must commit.

    Returns:
        The newly staged MemoryChunk.
    """
    piece = MemoryChunk(
        **{
            "id": _new_id(),
            "source_id": source_id,
            "user_id": user_id,
            "content": content,
            "chunk_index": chunk_index,
        }
    )
    db.add(piece)
    return piece
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_chunk_fts_sync(session: Session, chunk_id: str) -> None:
    """Recompute the full-text-search vector of one chunk (sync variant).

    Runs an UPDATE that sets ``content_tsv`` to
    ``to_tsvector('simple', content)`` for the ``memory_chunks`` row whose
    ``id`` equals ``chunk_id``. Nothing is committed here; the caller must
    commit.
    """
    refresh_sql = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    bind_values = {"id": chunk_id}
    session.execute(refresh_sql, bind_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def update_chunk_embedding(
    db: AsyncSession, chunk_id: str, embedding: list[float]
) -> None:
    """Assign ``embedding`` to the chunk identified by ``chunk_id``.

    The chunk is looked up with ``db.get``. If the lookup yields nothing, the
    call is a silent no-op. Nothing is committed here; the caller must commit.
    """
    target = await db.get(MemoryChunk, chunk_id)
    if not target:
        # No such chunk: nothing to update.
        return
    target.embedding = embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def update_chunk_fts(db: AsyncSession, chunk_id: str) -> None:
    """Recompute the full-text-search vector of one chunk (async variant).

    Runs an UPDATE that sets ``content_tsv`` to
    ``to_tsvector('simple', content)`` for the ``memory_chunks`` row whose
    ``id`` equals ``chunk_id``. Nothing is committed here; the caller must
    commit.
    """
    refresh_sql = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    await db.execute(refresh_sql, {"id": chunk_id})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def search_chunks_fts(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[dict]:
    """Run a full-text search over one user's memory chunks.

    The query text is stripped of surrounding whitespace; an empty or
    whitespace-only query returns ``[]`` without touching the database.
    Otherwise rows of ``memory_chunks`` owned by ``user_id`` that are not
    excluded and whose ``content_tsv`` matches the query are ranked with
    ``ts_rank_cd`` (best first) and capped at ``limit``.

    Returns:
        A list of dicts with the keys ``id``, ``content`` and ``chunk_index``.
    """
    needle = query.strip() if query else ""
    if not needle:
        return []

    search_sql = text("""
        SELECT id, content, chunk_index
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
        AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
        ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
        LIMIT :lim
    """)
    # The same search text is bound twice: once for matching, once for ranking.
    bind_values = {"user_id": user_id, "q": needle, "q2": needle, "lim": limit}
    outcome = await db.execute(search_sql, bind_values)

    wanted = ("id", "content", "chunk_index")
    hits = []
    for row in outcome.mappings().all():
        hits.append({key: row[key] for key in wanted})
    return hits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_chunks_by_ids(
    db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Fetch chunks by ID, ordered to follow ``chunk_ids``.

    Args:
        db: Async session used to run the query.
        chunk_ids: IDs to load. An empty list returns ``[]`` without querying.

    Returns:
        The matching MemoryChunk objects, sorted by the position of their ID
        in ``chunk_ids``. If an ID is repeated, the position of its last
        occurrence is used. IDs without a matching row are simply absent from
        the result.
    """
    if not chunk_ids:
        return []
    stmt = select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids))
    result = await db.execute(stmt)
    chunks = list(result.unique().scalars().all())
    position = {cid: i for i, cid in enumerate(chunk_ids)}
    # Any row whose ID is not found in the request sorts after every requested
    # position. len(chunk_ids) is always past the last valid index, whereas a
    # fixed constant would collide with real positions on long ID lists.
    unknown = len(chunk_ids)
    return sorted(chunks, key=lambda c: position.get(c.id, unknown))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_facts_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Return a user's most recently created confirmed facts.

    Only MemoryFact rows with ``status == "confirmed"`` that belong to
    ``user_id`` are considered. They are ordered by ``created_at``
    descending and capped at ``limit``.
    """
    confirmed_only = select(MemoryFact).where(
        MemoryFact.user_id == user_id, MemoryFact.status == "confirmed"
    )
    newest_first = confirmed_only.order_by(MemoryFact.created_at.desc())
    outcome = await db.execute(newest_first.limit(limit))
    return list(outcome.unique().scalars().all())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def search_chunks_vector(
    db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
    """Vector similarity search over one user's memory chunks.

    Args:
        db: Async session used to run the query.
        user_id: Only chunks owned by this user are searched.
        query_embedding: Query vector. An empty value returns ``[]`` without
            touching the database.
        limit: Maximum number of rows to return.

    Returns:
        A list of dicts with the keys ``id``, ``content``, ``chunk_index`` and
        ``distance`` (a float), ordered by ascending distance. Excluded chunks
        and chunks without an embedding are skipped.
    """
    if not query_embedding:
        return []
    # pgvector cosine distance: 1 - cosine_similarity, lower is better.
    # The cast is spelled CAST(... AS vector) instead of the "::vector"
    # shorthand: a colon directly after a bind name conflicts with the way
    # text() recognises ":name" bind parameters.
    stmt = text("""
        SELECT id, content, chunk_index,
               (embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
        AND embedding IS NOT NULL
        ORDER BY embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    # Serialise the vector as "[x1,x2,...]" text, which the database casts to vector.
    emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
    result = await db.execute(
        stmt,
        {"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit},
    )
    rows = result.mappings().all()
    return [
        {
            "id": r["id"],
            "content": r["content"],
            "chunk_index": r["chunk_index"],
            "distance": float(r["distance"]),
        }
        for r in rows
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_evidence_sync(
    session: Session, user_id: str, query: str, *, top_k: int = 10
) -> dict:
    """
    Sync evidence retrieval for Celery tasks.

    FTS only (no vector), returns evidence bundle.

    Args:
        session: Synchronous session used for all three queries.
        user_id: Owner of the chunks, facts and timeline events to consider.
        query: Search text. When empty or whitespace-only, an empty bundle is
            returned without touching the database.
        top_k: Row cap applied separately to chunks, facts and timeline events.

    Returns:
        A dict with the keys ``relevant_chunks``, ``relevant_summaries``,
        ``relevant_facts``, ``timeline_hints`` and ``relevant_stories``.
        ``relevant_summaries`` and ``relevant_stories`` are always empty lists.
    """
    # Single definition of the bundle shape; a fresh dict is built per call.
    bundle: dict = {
        "relevant_chunks": [],
        "relevant_summaries": [],
        "relevant_facts": [],
        "timeline_hints": [],
        "relevant_stories": [],
    }
    if not query or not query.strip():
        return bundle
    q = query.strip()

    # FTS chunks: matched against content_tsv and ranked with ts_rank_cd.
    stmt = text("""
        SELECT id, content, chunk_index
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
        AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
        ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
        LIMIT :lim
    """)
    result = session.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": top_k})
    bundle["relevant_chunks"] = [
        {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
        for r in result.mappings().all()
    ]

    # Facts: the newest confirmed facts for the user (not filtered by the query).
    facts_stmt = (
        select(MemoryFact)
        .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed")
        .order_by(MemoryFact.created_at.desc())
        .limit(top_k)
    )
    facts = session.execute(facts_stmt).unique().scalars().all()
    bundle["relevant_facts"] = [
        {
            "id": f.id,
            "fact_type": f.fact_type,
            "subject": f.subject,
            "predicate": f.predicate,
            "object_json": f.object_json,
        }
        for f in facts
    ]

    # Timeline: latest year first, undated events last. created_at breaks ties
    # so the rows kept by LIMIT are deterministic and match the ordering used
    # by get_timeline_events_for_user.
    events_stmt = (
        select(TimelineEvent)
        .where(TimelineEvent.user_id == user_id)
        .order_by(
            TimelineEvent.event_year.desc().nullslast(),
            TimelineEvent.created_at.desc(),
        )
        .limit(top_k)
    )
    events = session.execute(events_stmt).unique().scalars().all()
    bundle["timeline_hints"] = [
        {
            "id": e.id,
            "event_year": e.event_year,
            "event_date": e.event_date,
            "title": e.title,
            "description": e.description,
        }
        for e in events
    ]

    return bundle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_timeline_events_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[TimelineEvent]:
    """Return a user's timeline events, latest year first.

    Events are ordered by ``event_year`` descending with NULL years last,
    then by ``created_at`` descending, and capped at ``limit``.
    """
    ordering = (
        TimelineEvent.event_year.desc().nullslast(),
        TimelineEvent.created_at.desc(),
    )
    owned = select(TimelineEvent).where(TimelineEvent.user_id == user_id)
    outcome = await db.execute(owned.order_by(*ordering).limit(limit))
    return list(outcome.unique().scalars().all())
|