"""Memory repository — MemorySource, MemoryChunk, and MemoryFact data access."""

import uuid
from datetime import datetime, timedelta, timezone

from sqlalchemy import cast, literal, or_, select, text, tuple_, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.types import String as SqlString

from app.features.memory.models import (
    MemoryChunk,
    MemoryCurationAction,
    MemoryFact,
    MemorySource,
    MemorySummary,
)


def _new_id() -> str:
    """Return a new random UUID4 rendered as a string, used as a primary key."""
    return str(uuid.uuid4())


async def create_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
    lineage_json: dict | None = None,
    primary_user_message_id: str | None = None,
) -> MemorySource:
    """Create a memory source. Caller must commit.

    The source is added to the session with both ``embedding_status`` and
    ``enrichment_status`` set to ``"pending"``. ``captured_at`` defaults to
    the current UTC time when not supplied.
    """
    source = MemorySource(
        id=_new_id(),
        user_id=user_id,
        source_type=source_type,
        raw_text=raw_text,
        embedding_status="pending",
        enrichment_status="pending",
        conversation_id=conversation_id,
        lineage_json=lineage_json,
        primary_user_message_id=primary_user_message_id,
        captured_at=captured_at or datetime.now(timezone.utc),
    )
    db.add(source)
    return source


async def create_chunk(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Create a memory chunk. Caller must commit.

    The chunk is added to the session with ``embedding_status="pending"``.
    """
    chunk = MemoryChunk(
        id=_new_id(),
        source_id=source_id,
        user_id=user_id,
        content=content,
        chunk_index=chunk_index,
        embedding_status="pending",
    )
    db.add(chunk)
    return chunk


async def update_chunk_embedding(
    db: AsyncSession, chunk_id: str, embedding: list[float]
) -> bool:
    """Update chunk embedding. Caller must commit.

    Stores *embedding* on the chunk, marks ``embedding_status`` as
    ``"success"``, and clears any previous ``embedding_error``.

    Returns:
        True if the chunk existed and was updated, False if no chunk has
        ``chunk_id``. This mirrors ``set_chunk_embedding_status``. Callers
        that ignore the result are unaffected.

    Note: the lookup is by primary key only; ownership is not checked here.
    """
    chunk = await db.get(MemoryChunk, chunk_id)
    if chunk is None:
        return False
    chunk.embedding = embedding
    chunk.embedding_status = "success"
    chunk.embedding_error = None
    return True


async def set_chunk_embedding_status(
    db: AsyncSession,
    chunk_id: str,
    *,
    status: str,
    error: str | None = None,
) -> bool:
    """Set a chunk's embedding status and error text. Caller must commit.

    Returns False if no chunk has ``chunk_id``, otherwise True.
    Note: the lookup is by primary key only; ownership is not checked here.
    """
    chunk = await db.get(MemoryChunk, chunk_id)
    if chunk is None:
        return False
    chunk.embedding_status = status
    chunk.embedding_error = error
    return True


async def set_source_embedding_status(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    status: str,
    error: str | None = None,
) -> bool:
    """Set a source's embedding status and error text. Caller must commit.

    Returns False if the source does not exist or belongs to a different
    user, otherwise True.
    """
    source = await db.get(MemorySource, source_id)
    if source is None or source.user_id != user_id:
        return False
    source.embedding_status = status
    source.embedding_error = error
    return True


async def set_source_enrichment_status(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    status: str,
    error: str | None = None,
) -> bool:
    """Set a source's enrichment status and error text. Caller must commit.

    Returns False if the source does not exist or belongs to a different
    user, otherwise True.
    """
    source = await db.get(MemorySource, source_id)
    if source is None or source.user_id != user_id:
        return False
    source.enrichment_status = status
    source.enrichment_error = error
    return True


async def list_chunks_for_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_id: str,
    include_excluded: bool = True,
) -> list[MemoryChunk]:
    """List one user's chunks for a source, ordered by ``chunk_index`` then id.

    When ``include_excluded`` is False, only chunks whose ``is_excluded`` is
    False or NULL are returned.
    """
    stmt = (
        select(MemoryChunk)
        .where(MemoryChunk.user_id == user_id, MemoryChunk.source_id == source_id)
        .order_by(MemoryChunk.chunk_index.asc(), MemoryChunk.id.asc())
    )
    if not include_excluded:
        stmt = stmt.where(
            or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None))
        )
    result = await db.execute(stmt)
    return list(result.unique().scalars().all())


async def get_chunks_by_ids(
    db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Fetch chunks by IDs, returned in the order the IDs were requested.

    IDs with no matching row are omitted. If an ID is repeated, the chunk is
    returned once, at the position of its first occurrence.
    Note: not scoped by user; callers must enforce ownership if required.
    """
    if not chunk_ids:
        return []
    # Deduplicate while keeping first-occurrence order. The previous
    # ``{cid: i for ...}`` mapping kept the *last* index of a repeated id.
    unique_ids = list(dict.fromkeys(chunk_ids))
    stmt = select(MemoryChunk).where(MemoryChunk.id.in_(unique_ids))
    result = await db.execute(stmt)
    chunks = list(result.unique().scalars().all())
    rank = {cid: position for position, cid in enumerate(unique_ids)}
    # Defensive: a row whose id is not in ``rank`` sorts after every requested
    # id, regardless of how many ids were requested.
    fallback = len(unique_ids)
    return sorted(chunks, key=lambda c: rank.get(c.id, fallback))
# Escape character for LIKE/ILIKE patterns built from user input. This is the
# same character SQLAlchemy uses for ``contains(..., autoescape=True)``.
_LIKE_ESCAPE = "/"


def _contains_pattern(term: str) -> str:
    """Build an ILIKE "contains" pattern that matches *term* literally.

    ``%`` and ``_`` are LIKE wildcards. Spliced in raw, a search term such as
    ``"50%"`` or ``"snake_case"`` would match far more than the user typed.
    Each wildcard, and the escape character itself, is prefixed with
    ``_LIKE_ESCAPE``. Callers must pass ``escape=_LIKE_ESCAPE`` to ``ilike``.
    """
    escaped = (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


def _to_vector_literal(values: list[float]) -> str:
    """Render *values* as a bracketed, comma-separated literal such as ``[0.1,0.2]``.

    The result is passed to SQL as ``CAST(:emb AS vector)``.
    """
    return "[" + ",".join(str(x) for x in values) + "]"


async def get_facts_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Fetch a user's most recent confirmed facts, newest first, up to *limit*."""
    stmt = (
        select(MemoryFact)
        .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed")
        .order_by(MemoryFact.created_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    return list(result.unique().scalars().all())


async def search_facts_for_user_async(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[MemoryFact]:
    """Case-insensitive substring search over a user's confirmed facts.

    The query is matched against ``subject``, ``predicate`` and the string
    form of ``object_json``. LIKE wildcards in *query* are matched literally.
    Results are newest first, capped at *limit*. A blank or None query
    returns an empty list without touching the database.
    """
    q = (query or "").strip()
    if not q:
        return []
    pat = _contains_pattern(q)
    stmt = (
        select(MemoryFact)
        .where(
            MemoryFact.user_id == user_id,
            MemoryFact.status == "confirmed",
            or_(
                MemoryFact.subject.ilike(pat, escape=_LIKE_ESCAPE),
                MemoryFact.predicate.ilike(pat, escape=_LIKE_ESCAPE),
                cast(MemoryFact.object_json, SqlString).ilike(
                    pat, escape=_LIKE_ESCAPE
                ),
            ),
        )
        .order_by(MemoryFact.created_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    return list(result.unique().scalars().all())


async def mark_facts_stale_for_excluded_chunk(
    db: AsyncSession, *, user_id: str, chunk_id: str
) -> int:
    """Mark a user's facts derived from *chunk_id* as ``"stale"``.

    Only facts currently ``"confirmed"`` or ``"candidate"`` are changed.
    Returns the number of updated rows, or 0 if the driver reports none.
    Caller must commit.
    """
    stmt = (
        update(MemoryFact)
        .where(
            MemoryFact.user_id == user_id,
            MemoryFact.source_chunk_id == chunk_id,
            MemoryFact.status.in_(["confirmed", "candidate"]),
        )
        .values(status="stale")
    )
    res = await db.execute(stmt)
    return int(res.rowcount or 0)


async def search_chunks_vector(
    db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
    """Vector similarity search.

    Returns list of {id, content, chunk_index, distance}. Rows are ordered by
    ascending ``<=>`` distance, so closest first. Excluded chunks and chunks
    without an embedding are skipped. An empty *query_embedding* returns an
    empty list.
    """
    if not query_embedding:
        return []
    # pgvector cosine distance: 1 - cosine_similarity, lower is better
    stmt = text("""
        SELECT id, content, chunk_index,
               (embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks
        WHERE user_id = :user_id
          AND (is_excluded IS NOT TRUE OR is_excluded = false)
          AND embedding IS NOT NULL
        ORDER BY embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    emb_str = _to_vector_literal(query_embedding)
    result = await db.execute(
        stmt,
        {"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit},
    )
    rows = result.mappings().all()
    return [
        {
            "id": r["id"],
            "content": r["content"],
            "chunk_index": r["chunk_index"],
            "distance": float(r["distance"]),
        }
        for r in rows
    ]


async def list_users_with_recent_chunks(db: AsyncSession, *, hours: int) -> list[str]:
    """Return distinct ids of users with chunks created in the last *hours* hours.

    Used by the Beat compaction scan. *hours* is clamped to a minimum of 1.
    """
    hours = max(hours, 1)
    cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
    stmt = (
        select(MemoryChunk.user_id).where(MemoryChunk.created_at >= cutoff).distinct()
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())


async def list_storage_keys_for_conversation(
    db: AsyncSession, conversation_id: str
) -> list[str]:
    """Return the COS object keys recorded on the conversation's memory_sources.

    Only non-empty keys are included. The result is deduplicated and sorted.
    """
    stmt = select(MemorySource.storage_key).where(
        MemorySource.conversation_id == conversation_id,
        MemorySource.storage_key.isnot(None),
    )
    result = await db.execute(stmt)
    return sorted({r for r in result.scalars().all() if r})


async def create_memory_summary(
    db: AsyncSession,
    *,
    user_id: str,
    summary_type: str,
    content: str,
    source_chunk_ids: list[str] | None = None,
) -> MemorySummary:
    """Create a memory summary row and add it to the session. Caller must commit."""
    row = MemorySummary(
        id=_new_id(),
        user_id=user_id,
        summary_type=summary_type,
        content=content,
        source_chunk_ids=source_chunk_ids,
    )
    db.add(row)
    return row


async def create_memory_fact(
    db: AsyncSession,
    *,
    user_id: str,
    fact_type: str,
    subject: str | None,
    predicate: str | None,
    object_json: dict | None,
    confidence: float,
    source_chunk_id: str | None,
    status: str = "confirmed",
    lineage_json: dict | None = None,
) -> MemoryFact:
    """Create a memory fact row and add it to the session. Caller must commit."""
    row = MemoryFact(
        id=_new_id(),
        user_id=user_id,
        fact_type=fact_type,
        subject=subject,
        predicate=predicate,
        object_json=object_json,
        confidence=confidence,
        source_chunk_id=source_chunk_id,
        status=status,
        lineage_json=lineage_json,
    )
    db.add(row)
    return row


async def get_memory_fact_for_user(
    db: AsyncSession, fact_id: str, user_id: str
) -> MemoryFact | None:
    """Return the fact if it exists and belongs to *user_id*, else None."""
    row = await db.get(MemoryFact, fact_id)
    if row is None or row.user_id != user_id:
        return None
    return row


async def set_memory_fact_status(
    db: AsyncSession, fact_id: str, user_id: str, status: str
) -> bool:
    """Set a user's fact status. Caller must commit.

    Returns False if the fact does not exist or belongs to a different user.
    """
    row = await get_memory_fact_for_user(db, fact_id, user_id)
    if row is None:
        return False
    row.status = status
    return True


async def create_curation_action(
    db: AsyncSession,
    *,
    user_id: str,
    action_type: str,
    target_type: str,
    target_id: str,
    details: dict | None = None,
) -> MemoryCurationAction:
    """Record a curation action row and add it to the session. Caller must commit."""
    row = MemoryCurationAction(
        id=_new_id(),
        user_id=user_id,
        action_type=action_type,
        target_type=target_type,
        target_id=target_id,
        details=details,
    )
    db.add(row)
    return row


async def get_memory_chunk_for_user(
    db: AsyncSession, chunk_id: str, user_id: str
) -> MemoryChunk | None:
    """Return the chunk if it exists and belongs to *user_id*, else None."""
    row = await db.get(MemoryChunk, chunk_id)
    if row is None or row.user_id != user_id:
        return None
    return row


async def list_incremental_chunks_for_compaction(
    db: AsyncSession,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
    limit: int,
    candidate_chunk_ids: list[str] | None = None,
    candidate_source_ids: list[str] | None = None,
) -> list[MemoryChunk]:
    """Return the next page of non-excluded chunks strictly after a cursor.

    Paging is keyset-based on ``(created_at, id)``. A row is returned only if
    its tuple is greater than ``(after_cursor_ts, after_chunk_id)``. Results
    are ordered by that same tuple and capped at *limit*.

    The optional candidate lists narrow the page to those chunk / source ids.
    NOTE(review): an empty list is falsy, so it applies no filter (same as
    None) rather than matching nothing. Confirm callers expect that.
    """
    stmt = select(MemoryChunk).where(
        MemoryChunk.user_id == user_id,
        tuple_(MemoryChunk.created_at, MemoryChunk.id)
        > tuple_(literal(after_cursor_ts), literal(after_chunk_id)),
        or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None)),
    )
    if candidate_chunk_ids:
        stmt = stmt.where(MemoryChunk.id.in_(candidate_chunk_ids))
    if candidate_source_ids:
        stmt = stmt.where(MemoryChunk.source_id.in_(candidate_source_ids))
    stmt = stmt.order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc()).limit(
        limit
    )
    result = await db.execute(stmt)
    return list(result.unique().scalars().all())


async def get_first_chunk_after_cursor(
    db: AsyncSession,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
) -> MemoryChunk | None:
    """Return the user's first chunk strictly after ``(after_cursor_ts, after_chunk_id)``.

    Returns None if there is no such chunk.
    NOTE(review): unlike ``list_incremental_chunks_for_compaction``, excluded
    chunks are NOT filtered out here. Confirm this is intentional.
    """
    stmt = (
        select(MemoryChunk)
        .where(
            MemoryChunk.user_id == user_id,
            tuple_(MemoryChunk.created_at, MemoryChunk.id)
            > tuple_(literal(after_cursor_ts), literal(after_chunk_id)),
        )
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalars().first()


async def search_nearest_chunks_for_compaction(
    db: AsyncSession,
    *,
    user_id: str,
    chunk_id: str,
    query_embedding: list[float],
    limit: int,
) -> list[dict]:
    """Find the user's chunks nearest to *query_embedding*, excluding *chunk_id*.

    Each chunk is joined with its memory_source to report ``source_type``.
    Excluded chunks and chunks without an embedding are skipped. Results are
    ordered by ascending ``<=>`` distance and capped at *limit*. An empty
    *query_embedding* returns an empty list.
    """
    if not query_embedding:
        return []
    stmt = text("""
        SELECT mc.id, mc.content, mc.source_id, mc.event_year, mc.metadata_json,
               ms.source_type, mc.created_at,
               (mc.embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks mc
        JOIN memory_sources ms ON ms.id = mc.source_id
        WHERE mc.user_id = :user_id
          AND (mc.is_excluded IS NOT TRUE OR mc.is_excluded = false)
          AND mc.embedding IS NOT NULL
          AND mc.id != :chunk_id
        ORDER BY mc.embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    emb_str = _to_vector_literal(query_embedding)
    result = await db.execute(
        stmt,
        {
            "user_id": user_id,
            "chunk_id": chunk_id,
            "emb": emb_str,
            "emb2": emb_str,
            "lim": limit,
        },
    )
    return [
        {
            "id": r["id"],
            "content": r["content"],
            "source_id": r["source_id"],
            "event_year": r["event_year"],
            "metadata_json": r["metadata_json"],
            "source_type": r["source_type"],
            "created_at": r["created_at"],
            "distance": float(r["distance"]),
        }
        for r in result.mappings().all()
    ]


async def set_chunk_excluded(
    db: AsyncSession, chunk_id: str, user_id: str, excluded: bool
) -> bool:
    """Set ``is_excluded`` on a user's chunk. Caller must commit.

    Returns False if the chunk does not exist or belongs to a different user.
    Only the flag is changed; facts derived from the chunk are left untouched.
    """
    row = await get_memory_chunk_for_user(db, chunk_id, user_id)
    if row is None:
        return False
    row.is_excluded = excluded
    return True


async def list_summaries_for_evidence_async(
    db: AsyncSession, *, user_id: str, q: str, limit: int
) -> list[dict]:
    """Case-insensitive substring search over a user's ``"session"`` summaries.

    LIKE wildcards in *q* are matched literally. Results are ordered by most
    recently updated and capped at *limit*. Each result is a dict with
    ``id``, ``summary_type``, ``content`` and ``source_chunk_ids``. A blank
    or None *q* returns an empty list without touching the database.
    """
    term = (q or "").strip()
    if not term:
        return []
    pat = _contains_pattern(term)
    stmt = (
        select(MemorySummary)
        .where(
            MemorySummary.user_id == user_id,
            MemorySummary.summary_type == "session",
            MemorySummary.content.ilike(pat, escape=_LIKE_ESCAPE),
        )
        .order_by(MemorySummary.updated_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    return [
        {
            "id": s.id,
            "summary_type": s.summary_type,
            "content": s.content,
            "source_chunk_ids": s.source_chunk_ids,
        }
        for s in result.unique().scalars().all()
    ]