"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access.""" import uuid from datetime import datetime, timezone from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.features.memory.models import ( MemoryChunk, MemoryFact, MemorySource, TimelineEvent, ) def _new_id() -> str: return str(uuid.uuid4()) def create_source_sync( session: Session, *, user_id: str, source_type: str, raw_text: str | None = None, conversation_id: str | None = None, captured_at: datetime | None = None, ) -> MemorySource: """Create a memory source (sync). Caller must commit.""" source = MemorySource( id=_new_id(), user_id=user_id, source_type=source_type, raw_text=raw_text, conversation_id=conversation_id, captured_at=captured_at or datetime.now(timezone.utc), ) session.add(source) return source async def create_source( db: AsyncSession, *, user_id: str, source_type: str, raw_text: str | None = None, conversation_id: str | None = None, captured_at: datetime | None = None, ) -> MemorySource: """Create a memory source. Caller must commit.""" source = MemorySource( id=_new_id(), user_id=user_id, source_type=source_type, raw_text=raw_text, conversation_id=conversation_id, captured_at=captured_at or datetime.now(timezone.utc), ) db.add(source) return source def create_chunk_sync( session: Session, *, source_id: str, user_id: str, content: str, chunk_index: int, ) -> MemoryChunk: """Create a memory chunk (sync). Caller must commit.""" chunk = MemoryChunk( id=_new_id(), source_id=source_id, user_id=user_id, content=content, chunk_index=chunk_index, ) session.add(chunk) return chunk async def create_chunk( db: AsyncSession, *, source_id: str, user_id: str, content: str, chunk_index: int, ) -> MemoryChunk: """Create a memory chunk. Caller must commit.""" chunk = MemoryChunk( id=_new_id(), source_id=source_id, user_id=user_id, content=content, chunk_index=chunk_index, ) db.add(chunk) return chunk def update_chunk_fts_sync(session: Session, chunk_id: str) -> None: """Populate content_tsv for FTS (sync). Caller must commit.""" session.execute( text( "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id" ), {"id": chunk_id}, ) async def update_chunk_embedding( db: AsyncSession, chunk_id: str, embedding: list[float] ) -> None: """Update chunk embedding. Caller must commit.""" chunk = await db.get(MemoryChunk, chunk_id) if chunk: chunk.embedding = embedding async def update_chunk_fts(db: AsyncSession, chunk_id: str) -> None: """Populate content_tsv for FTS. Caller must commit.""" await db.execute( text( "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id" ), {"id": chunk_id}, ) async def search_chunks_fts( db: AsyncSession, user_id: str, query: str, limit: int = 20 ) -> list[dict]: """FTS search on memory_chunks. Returns list of {id, content, chunk_index}.""" if not query or not query.strip(): return [] q = query.strip() stmt = text(""" SELECT id, content, chunk_index FROM memory_chunks WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false) AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q) ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC LIMIT :lim """) result = await db.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit}) rows = result.mappings().all() return [ {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]} for r in rows ] async def get_chunks_by_ids( db: AsyncSession, chunk_ids: list[str] ) -> list[MemoryChunk]: """Fetch chunks by IDs.""" if not chunk_ids: return [] stmt = select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids)) result = await db.execute(stmt) chunks = list(result.unique().scalars().all()) order = {cid: i for i, cid in enumerate(chunk_ids)} return sorted(chunks, key=lambda c: order.get(c.id, 999)) async def get_facts_for_user( db: AsyncSession, user_id: str, limit: int = 20 ) -> list[MemoryFact]: """Fetch recent facts for user.""" stmt = ( select(MemoryFact) .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed") .order_by(MemoryFact.created_at.desc()) .limit(limit) ) result = await db.execute(stmt) return list(result.unique().scalars().all()) async def search_chunks_vector( db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20 ) -> list[dict]: """Vector similarity search. Returns list of {id, content, chunk_index, distance}.""" if not query_embedding: return [] # pgvector cosine distance: 1 - cosine_similarity, lower is better stmt = text(""" SELECT id, content, chunk_index, (embedding <=> :emb::vector) AS distance FROM memory_chunks WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false) AND embedding IS NOT NULL ORDER BY embedding <=> :emb2::vector LIMIT :lim """) emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]" result = await db.execute( stmt, {"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit}, ) rows = result.mappings().all() return [ { "id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"], "distance": float(r["distance"]), } for r in rows ] def retrieve_evidence_sync( session: Session, user_id: str, query: str, *, top_k: int = 10 ) -> dict: """ Sync evidence retrieval for Celery tasks. 能力:**仅 FTS** 检索 chunks(与 `HybridRetriever` 的 FTS+向量 RRF 不同,见 `api/docs/memory-retrieval.md`);confirmed facts;timeline。 ingest 在 Celery 任务内先于 story 流水线执行并已 commit,故本检索可见刚写入的 chunk。 """ if not query or not query.strip(): return { "relevant_chunks": [], "relevant_summaries": [], "relevant_facts": [], "timeline_hints": [], "relevant_stories": [], } q = query.strip() # FTS chunks stmt = text(""" SELECT id, content, chunk_index FROM memory_chunks WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false) AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q) ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC LIMIT :lim """) result = session.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": top_k}) rows = result.mappings().all() relevant_chunks = [ {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]} for r in rows ] # Facts facts_stmt = ( select(MemoryFact) .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed") .order_by(MemoryFact.created_at.desc()) .limit(top_k) ) facts = list(session.execute(facts_stmt).unique().scalars().all()) relevant_facts = [ { "id": f.id, "fact_type": f.fact_type, "subject": f.subject, "predicate": f.predicate, "object_json": f.object_json, } for f in facts ] # Timeline events_stmt = ( select(TimelineEvent) .where(TimelineEvent.user_id == user_id) .order_by(TimelineEvent.event_year.desc().nullslast()) .limit(top_k) ) events = list(session.execute(events_stmt).unique().scalars().all()) timeline_hints = [ { "id": e.id, "event_year": e.event_year, "event_date": e.event_date, "title": e.title, "description": e.description, } for e in events ] return { "relevant_chunks": relevant_chunks, "relevant_summaries": [], "relevant_facts": relevant_facts, "timeline_hints": timeline_hints, "relevant_stories": [], } async def get_timeline_events_for_user( db: AsyncSession, user_id: str, limit: int = 20 ) -> list[TimelineEvent]: """Fetch timeline events for user.""" stmt = ( select(TimelineEvent) .where(TimelineEvent.user_id == user_id) .order_by( TimelineEvent.event_year.desc().nullslast(), TimelineEvent.created_at.desc() ) .limit(limit) ) result = await db.execute(stmt) return list(result.unique().scalars().all()) async def list_storage_keys_for_conversation( db: AsyncSession, conversation_id: str ) -> list[str]: """对话关联的 memory_sources 上记录的 COS object key(若有)。""" stmt = select(MemorySource.storage_key).where( MemorySource.conversation_id == conversation_id, MemorySource.storage_key.isnot(None), ) result = await db.execute(stmt) return sorted({r for r in result.scalars().all() if r})