"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access.""" import uuid from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING from sqlalchemy import delete, literal, or_, select, text, tuple_, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.features.memory.models import ( MemoryChunk, MemoryCurationAction, MemoryFact, MemorySource, MemorySummary, TimelineEvent, ) if TYPE_CHECKING: from app.ports.embedding import EmbeddingProvider def _new_id() -> str: return str(uuid.uuid4()) def create_source_sync( session: Session, *, user_id: str, source_type: str, raw_text: str | None = None, conversation_id: str | None = None, captured_at: datetime | None = None, ) -> MemorySource: """Create a memory source (sync). Caller must commit.""" source = MemorySource( id=_new_id(), user_id=user_id, source_type=source_type, raw_text=raw_text, conversation_id=conversation_id, captured_at=captured_at or datetime.now(timezone.utc), ) session.add(source) return source async def create_source( db: AsyncSession, *, user_id: str, source_type: str, raw_text: str | None = None, conversation_id: str | None = None, captured_at: datetime | None = None, ) -> MemorySource: """Create a memory source. Caller must commit.""" source = MemorySource( id=_new_id(), user_id=user_id, source_type=source_type, raw_text=raw_text, conversation_id=conversation_id, captured_at=captured_at or datetime.now(timezone.utc), ) db.add(source) return source def create_chunk_sync( session: Session, *, source_id: str, user_id: str, content: str, chunk_index: int, ) -> MemoryChunk: """Create a memory chunk (sync). Caller must commit.""" chunk = MemoryChunk( id=_new_id(), source_id=source_id, user_id=user_id, content=content, chunk_index=chunk_index, ) session.add(chunk) return chunk async def create_chunk( db: AsyncSession, *, source_id: str, user_id: str, content: str, chunk_index: int, ) -> MemoryChunk: """Create a memory chunk. Caller must commit.""" chunk = MemoryChunk( id=_new_id(), source_id=source_id, user_id=user_id, content=content, chunk_index=chunk_index, ) db.add(chunk) return chunk def update_chunk_embedding_sync( session: Session, chunk_id: str, embedding: list[float] ) -> None: """Update chunk embedding (sync). Caller must commit.""" chunk = session.get(MemoryChunk, chunk_id) if chunk: chunk.embedding = embedding async def update_chunk_embedding( db: AsyncSession, chunk_id: str, embedding: list[float] ) -> None: """Update chunk embedding. Caller must commit.""" chunk = await db.get(MemoryChunk, chunk_id) if chunk: chunk.embedding = embedding async def get_chunks_by_ids( db: AsyncSession, chunk_ids: list[str] ) -> list[MemoryChunk]: """Fetch chunks by IDs.""" if not chunk_ids: return [] stmt = select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids)) result = await db.execute(stmt) chunks = list(result.unique().scalars().all()) order = {cid: i for i, cid in enumerate(chunk_ids)} return sorted(chunks, key=lambda c: order.get(c.id, 999)) async def get_facts_for_user( db: AsyncSession, user_id: str, limit: int = 20 ) -> list[MemoryFact]: """Fetch recent facts for user.""" stmt = ( select(MemoryFact) .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed") .order_by(MemoryFact.created_at.desc()) .limit(limit) ) result = await db.execute(stmt) return list(result.unique().scalars().all()) def get_facts_for_user_sync( session: Session, user_id: str, limit: int = 20 ) -> list[MemoryFact]: stmt = ( select(MemoryFact) .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed") .order_by(MemoryFact.created_at.desc()) .limit(limit) ) return list(session.execute(stmt).unique().scalars().all()) def get_timeline_events_for_user_sync( session: Session, user_id: str, limit: int = 20 ) -> list[TimelineEvent]: stmt = ( select(TimelineEvent) .where(TimelineEvent.user_id == user_id) .order_by( TimelineEvent.event_year.desc().nullslast(), TimelineEvent.created_at.desc() ) .limit(limit) ) return list(session.execute(stmt).unique().scalars().all()) def search_facts_for_user_sync( session: Session, user_id: str, query: str, limit: int = 20 ) -> list[MemoryFact]: from app.core.config import settings q = (query or "").strip() if not q: return get_facts_for_user_sync(session, user_id, limit) pat = f"%{q}%" stmt = ( select(MemoryFact) .where( MemoryFact.user_id == user_id, MemoryFact.status == "confirmed", or_(MemoryFact.subject.ilike(pat), MemoryFact.predicate.ilike(pat)), ) .order_by(MemoryFact.created_at.desc()) .limit(limit) ) rows = list(session.execute(stmt).unique().scalars().all()) if rows: return rows if settings.memory_fact_search_use_recent_fallback: return get_facts_for_user_sync(session, user_id, limit) return [] async def search_facts_for_user_async( db: AsyncSession, user_id: str, query: str, limit: int = 20 ) -> list[MemoryFact]: from app.core.config import settings q = (query or "").strip() if not q: return await get_facts_for_user(db, user_id=user_id, limit=limit) pat = f"%{q}%" stmt = ( select(MemoryFact) .where( MemoryFact.user_id == user_id, MemoryFact.status == "confirmed", or_(MemoryFact.subject.ilike(pat), MemoryFact.predicate.ilike(pat)), ) .order_by(MemoryFact.created_at.desc()) .limit(limit) ) result = await db.execute(stmt) rows = list(result.unique().scalars().all()) if rows: return rows if settings.memory_fact_search_use_recent_fallback: return await get_facts_for_user(db, user_id=user_id, limit=limit) return [] def mark_facts_stale_for_excluded_chunk_sync( session: Session, *, user_id: str, chunk_id: str ) -> int: """ Compaction 软排除 chunk 后:将与该 chunk 绑定的候选/已确认事实标为 stale, 避免 derive-only 断言在失去原文支撑后仍作为权威 evidence。 """ stmt = ( update(MemoryFact) .where( MemoryFact.user_id == user_id, MemoryFact.source_chunk_id == chunk_id, MemoryFact.status.in_(["confirmed", "candidate"]), ) .values(status="stale") ) res = session.execute(stmt) return int(res.rowcount or 0) def search_timeline_events_for_user_sync( session: Session, user_id: str, query: str, limit: int = 20 ) -> list[TimelineEvent]: q = (query or "").strip() if not q: return get_timeline_events_for_user_sync(session, user_id, limit) pat = f"%{q}%" stmt = ( select(TimelineEvent) .where( TimelineEvent.user_id == user_id, or_( TimelineEvent.title.ilike(pat), TimelineEvent.description.ilike(pat), ), ) .order_by(TimelineEvent.event_year.desc().nullslast()) .limit(limit) ) rows = list(session.execute(stmt).unique().scalars().all()) if rows: return rows return get_timeline_events_for_user_sync(session, user_id, limit) async def search_timeline_events_for_user_async( db: AsyncSession, user_id: str, query: str, limit: int = 20 ) -> list[TimelineEvent]: q = (query or "").strip() if not q: return await get_timeline_events_for_user(db, user_id=user_id, limit=limit) pat = f"%{q}%" stmt = ( select(TimelineEvent) .where( TimelineEvent.user_id == user_id, or_( TimelineEvent.title.ilike(pat), TimelineEvent.description.ilike(pat), ), ) .order_by(TimelineEvent.event_year.desc().nullslast()) .limit(limit) ) result = await db.execute(stmt) rows = list(result.unique().scalars().all()) if rows: return rows return await get_timeline_events_for_user(db, user_id=user_id, limit=limit) async def search_chunks_vector( db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20 ) -> list[dict]: """Vector similarity search. Returns list of {id, content, chunk_index, distance}.""" if not query_embedding: return [] # pgvector cosine distance: 1 - cosine_similarity, lower is better stmt = text(""" SELECT id, content, chunk_index, (embedding <=> CAST(:emb AS vector)) AS distance FROM memory_chunks WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false) AND embedding IS NOT NULL ORDER BY embedding <=> CAST(:emb2 AS vector) LIMIT :lim """) emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]" result = await db.execute( stmt, {"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit}, ) rows = result.mappings().all() return [ { "id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"], "distance": float(r["distance"]), } for r in rows ] def search_chunks_vector_sync( session: Session, user_id: str, query_embedding: list[float], limit: int = 20 ) -> list[dict]: """pgvector 余弦距离检索(sync,Celery)。返回 {id, content, chunk_index, distance}。""" if not query_embedding: return [] stmt = text(""" SELECT id, content, chunk_index, (embedding <=> CAST(:emb AS vector)) AS distance FROM memory_chunks WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false) AND embedding IS NOT NULL ORDER BY embedding <=> CAST(:emb2 AS vector) LIMIT :lim """) emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]" result = session.execute( stmt, {"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit}, ) rows = result.mappings().all() return [ { "id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"], "distance": float(r["distance"]), } for r in rows ] def list_users_with_recent_chunks_sync(session: Session, *, hours: int) -> list[str]: """最近 N 小时内有新 chunk 的用户 id(Beat compaction 扫描)。""" if hours < 1: hours = 1 cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) stmt = ( select(MemoryChunk.user_id).where(MemoryChunk.created_at >= cutoff).distinct() ) return list(session.execute(stmt).scalars().all()) def list_summaries_for_evidence_sync( session: Session, *, user_id: str, q: str, limit: int ) -> list[dict]: """最新 rolling + 内容匹配 query 的摘要(ILIKE)。""" pat = f"%{q}%" rolling = ( session.execute( select(MemorySummary) .where( MemorySummary.user_id == user_id, MemorySummary.summary_type == "rolling", ) .order_by(MemorySummary.updated_at.desc()) .limit(1) ) .unique() .scalar_one_or_none() ) rows: list[MemorySummary] = [] seen: set[str] = set() if rolling: rows.append(rolling) seen.add(rolling.id) rest = limit - len(rows) if rest > 0: stmt = ( select(MemorySummary) .where( MemorySummary.user_id == user_id, MemorySummary.content.ilike(pat), ) .order_by(MemorySummary.updated_at.desc()) .limit(rest + len(seen)) ) for s in session.execute(stmt).unique().scalars().all(): if s.id not in seen: rows.append(s) seen.add(s.id) if len(rows) >= limit: break return [ { "id": s.id, "summary_type": s.summary_type, "content": s.content, "source_chunk_ids": s.source_chunk_ids, } for s in rows[:limit] ] def retrieve_evidence_sync( session: Session, user_id: str, query: str, *, top_k: int = 10, embedding_provider: "EmbeddingProvider | None" = None, ) -> dict: """ Sync evidence retrieval for Celery tasks. chunks:**向量**(pgvector)与异步 `HybridRetriever` 对齐;facts/timeline 按 query ILIKE。 """ from app.features.memory.evidence import retrieve_evidence_bundle_sync return retrieve_evidence_bundle_sync( session, user_id, query, top_k=top_k, embedding_provider=embedding_provider, ) async def get_timeline_events_for_user( db: AsyncSession, user_id: str, limit: int = 20 ) -> list[TimelineEvent]: """Fetch timeline events for user.""" stmt = ( select(TimelineEvent) .where(TimelineEvent.user_id == user_id) .order_by( TimelineEvent.event_year.desc().nullslast(), TimelineEvent.created_at.desc() ) .limit(limit) ) result = await db.execute(stmt) return list(result.unique().scalars().all()) async def list_storage_keys_for_conversation( db: AsyncSession, conversation_id: str ) -> list[str]: """对话关联的 memory_sources 上记录的 COS object key(若有)。""" stmt = select(MemorySource.storage_key).where( MemorySource.conversation_id == conversation_id, MemorySource.storage_key.isnot(None), ) result = await db.execute(stmt) return sorted({r for r in result.scalars().all() if r}) def list_chunks_for_source_sync(session: Session, source_id: str) -> list[MemoryChunk]: stmt = ( select(MemoryChunk) .where(MemoryChunk.source_id == source_id) .order_by(MemoryChunk.chunk_index.asc()) ) return list(session.execute(stmt).unique().scalars().all()) def create_memory_summary_sync( session: Session, *, user_id: str, summary_type: str, content: str, source_chunk_ids: list[str] | None = None, ) -> MemorySummary: row = MemorySummary( id=_new_id(), user_id=user_id, summary_type=summary_type, content=content, source_chunk_ids=source_chunk_ids, ) session.add(row) return row async def create_memory_summary( db: AsyncSession, *, user_id: str, summary_type: str, content: str, source_chunk_ids: list[str] | None = None, ) -> MemorySummary: row = MemorySummary( id=_new_id(), user_id=user_id, summary_type=summary_type, content=content, source_chunk_ids=source_chunk_ids, ) db.add(row) return row def get_latest_rolling_summary_sync( session: Session, user_id: str ) -> MemorySummary | None: stmt = ( select(MemorySummary) .where( MemorySummary.user_id == user_id, MemorySummary.summary_type == "rolling", ) .order_by(MemorySummary.updated_at.desc()) .limit(1) ) return session.execute(stmt).unique().scalar_one_or_none() def upsert_rolling_summary_sync( session: Session, *, user_id: str, content: str, source_chunk_ids: list[str] | None = None, ) -> MemorySummary: existing = get_latest_rolling_summary_sync(session, user_id) if existing: existing.content = content if source_chunk_ids is not None: existing.source_chunk_ids = source_chunk_ids return existing return create_memory_summary_sync( session, user_id=user_id, summary_type="rolling", content=content, source_chunk_ids=source_chunk_ids, ) def create_memory_fact_sync( session: Session, *, user_id: str, fact_type: str, subject: str | None, predicate: str | None, object_json: dict | None, confidence: float, source_chunk_id: str | None, status: str = "confirmed", ) -> MemoryFact: row = MemoryFact( id=_new_id(), user_id=user_id, fact_type=fact_type, subject=subject, predicate=predicate, object_json=object_json, confidence=confidence, source_chunk_id=source_chunk_id, status=status, ) session.add(row) return row async def create_memory_fact( db: AsyncSession, *, user_id: str, fact_type: str, subject: str | None, predicate: str | None, object_json: dict | None, confidence: float, source_chunk_id: str | None, status: str = "confirmed", ) -> MemoryFact: row = MemoryFact( id=_new_id(), user_id=user_id, fact_type=fact_type, subject=subject, predicate=predicate, object_json=object_json, confidence=confidence, source_chunk_id=source_chunk_id, status=status, ) db.add(row) return row async def get_memory_fact_for_user( db: AsyncSession, fact_id: str, user_id: str ) -> MemoryFact | None: row = await db.get(MemoryFact, fact_id) if row is None or row.user_id != user_id: return None return row async def set_memory_fact_status( db: AsyncSession, fact_id: str, user_id: str, status: str ) -> bool: row = await get_memory_fact_for_user(db, fact_id, user_id) if row is None: return False row.status = status return True def delete_timeline_events_by_memory_source_sync( session: Session, *, user_id: str, memory_source_id: str ) -> int: stmt = delete(TimelineEvent).where( TimelineEvent.user_id == user_id, TimelineEvent.memory_source_id == memory_source_id, ) result = session.execute(stmt) return result.rowcount or 0 async def delete_timeline_events_by_memory_source( db: AsyncSession, *, user_id: str, memory_source_id: str ) -> int: stmt = delete(TimelineEvent).where( TimelineEvent.user_id == user_id, TimelineEvent.memory_source_id == memory_source_id, ) result = await db.execute(stmt) return result.rowcount or 0 def create_timeline_event_sync( session: Session, *, user_id: str, event_year: int | None, event_date: str | None, title: str, description: str | None, person_refs: list | None = None, source_fact_ids: list[str] | None = None, memory_source_id: str | None = None, ) -> TimelineEvent: row = TimelineEvent( id=_new_id(), user_id=user_id, memory_source_id=memory_source_id, event_year=event_year, event_date=event_date, title=title, description=description, person_refs=person_refs, source_fact_ids=source_fact_ids, ) session.add(row) return row async def create_timeline_event( db: AsyncSession, *, user_id: str, event_year: int | None, event_date: str | None, title: str, description: str | None, person_refs: list | None = None, source_fact_ids: list[str] | None = None, memory_source_id: str | None = None, ) -> TimelineEvent: row = TimelineEvent( id=_new_id(), user_id=user_id, memory_source_id=memory_source_id, event_year=event_year, event_date=event_date, title=title, description=description, person_refs=person_refs, source_fact_ids=source_fact_ids, ) db.add(row) return row def create_curation_action_sync( session: Session, *, user_id: str, action_type: str, target_type: str, target_id: str, details: dict | None = None, ) -> MemoryCurationAction: row = MemoryCurationAction( id=_new_id(), user_id=user_id, action_type=action_type, target_type=target_type, target_id=target_id, details=details, ) session.add(row) return row async def create_curation_action( db: AsyncSession, *, user_id: str, action_type: str, target_type: str, target_id: str, details: dict | None = None, ) -> MemoryCurationAction: row = MemoryCurationAction( id=_new_id(), user_id=user_id, action_type=action_type, target_type=target_type, target_id=target_id, details=details, ) db.add(row) return row async def get_memory_chunk_for_user( db: AsyncSession, chunk_id: str, user_id: str ) -> MemoryChunk | None: row = await db.get(MemoryChunk, chunk_id) if row is None or row.user_id != user_id: return None return row def get_memory_chunk_sync( session: Session, chunk_id: str, user_id: str ) -> MemoryChunk | None: row = session.get(MemoryChunk, chunk_id) if row is None or row.user_id != user_id: return None return row def set_chunk_excluded_sync( session: Session, chunk_id: str, user_id: str, excluded: bool ) -> bool: row = get_memory_chunk_sync(session, chunk_id, user_id) if row is None: return False row.is_excluded = excluded return True def list_incremental_chunks_for_compaction_sync( session: Session, *, user_id: str, after_cursor_ts: datetime, after_chunk_id: str, limit: int, candidate_chunk_ids: list[str] | None = None, candidate_source_ids: list[str] | None = None, ) -> list[MemoryChunk]: """增量 chunk:(created_at, id) 字典序大于游标;可选与候选 id/source 求交。""" stmt = ( select(MemoryChunk) .where( MemoryChunk.user_id == user_id, tuple_(MemoryChunk.created_at, MemoryChunk.id) > tuple_(literal(after_cursor_ts), literal(after_chunk_id)), or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None)), ) .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc()) .limit(limit) ) if candidate_chunk_ids: stmt = stmt.where(MemoryChunk.id.in_(candidate_chunk_ids)) if candidate_source_ids: stmt = stmt.where(MemoryChunk.source_id.in_(candidate_source_ids)) rows = session.execute(stmt).unique().scalars().all() return list(rows) def get_first_chunk_after_cursor_sync( session: Session, *, user_id: str, after_cursor_ts: datetime, after_chunk_id: str, ) -> MemoryChunk | None: """游标之后字典序第一条 chunk(含 excluded),用于空增量时推进游标。""" stmt = ( select(MemoryChunk) .where( MemoryChunk.user_id == user_id, tuple_(MemoryChunk.created_at, MemoryChunk.id) > tuple_(literal(after_cursor_ts), literal(after_chunk_id)), ) .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc()) .limit(1) ) return session.execute(stmt).scalars().first() def search_nearest_chunks_for_compaction_sync( session: Session, *, user_id: str, chunk_id: str, query_embedding: list[float], limit: int, ) -> list[dict]: """ 按余弦距离取 Top-K 近邻(不含自身)。pgvector `<=>` 为 cosine distance。 返回 dict: id, content, source_id, event_year, metadata_json, source_type, distance, created_at """ if not query_embedding: return [] stmt = text(""" SELECT mc.id, mc.content, mc.source_id, mc.event_year, mc.metadata_json, ms.source_type, mc.created_at, (mc.embedding <=> CAST(:emb AS vector)) AS distance FROM memory_chunks mc JOIN memory_sources ms ON ms.id = mc.source_id WHERE mc.user_id = :user_id AND (mc.is_excluded IS NOT TRUE OR mc.is_excluded = false) AND mc.embedding IS NOT NULL AND mc.id != :chunk_id ORDER BY mc.embedding <=> CAST(:emb2 AS vector) LIMIT :lim """) emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]" result = session.execute( stmt, { "user_id": user_id, "chunk_id": chunk_id, "emb": emb_str, "emb2": emb_str, "lim": limit, }, ) return [ { "id": r["id"], "content": r["content"], "source_id": r["source_id"], "event_year": r["event_year"], "metadata_json": r["metadata_json"], "source_type": r["source_type"], "created_at": r["created_at"], "distance": float(r["distance"]), } for r in result.mappings().all() ] async def set_chunk_excluded( db: AsyncSession, chunk_id: str, user_id: str, excluded: bool ) -> bool: row = await get_memory_chunk_for_user(db, chunk_id, user_id) if row is None: return False row.is_excluded = excluded return True async def list_summaries_for_evidence_async( db: AsyncSession, *, user_id: str, q: str, limit: int ) -> list[dict]: if not (q or "").strip(): return [] pat = f"%{q.strip()}%" rolling_stmt = ( select(MemorySummary) .where( MemorySummary.user_id == user_id, MemorySummary.summary_type == "rolling", ) .order_by(MemorySummary.updated_at.desc()) .limit(1) ) r_result = await db.execute(rolling_stmt) rolling = r_result.unique().scalar_one_or_none() rows: list[MemorySummary] = [] seen: set[str] = set() if rolling: rows.append(rolling) seen.add(rolling.id) rest = limit - len(rows) if rest > 0: stmt = ( select(MemorySummary) .where( MemorySummary.user_id == user_id, MemorySummary.content.ilike(pat), ) .order_by(MemorySummary.updated_at.desc()) .limit(rest + len(seen)) ) o_result = await db.execute(stmt) for s in o_result.unique().scalars().all(): if s.id not in seen: rows.append(s) seen.add(s.id) if len(rows) >= limit: break return [ { "id": s.id, "summary_type": s.summary_type, "content": s.content, "source_chunk_ids": s.source_chunk_ids, } for s in rows[:limit] ]