"""
MemoryService — unified facade for conversation / memoir memory.

- ingest_transcript: transcript -> memory_sources, chunks, embedding, FTS
- After ingest, optional LLM enrichment (session/rolling summaries, facts,
  timeline).
- retrieve: delegates to HybridRetriever and returns an evidence bundle
  (FTS + optional vector RRF).

Celery tasks use `ingest_transcript_sync` + `retrieve_evidence_sync`; for how
they differ from the async path see `api/docs/memory-retrieval.md`.
"""

import asyncio

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.features.memory.chunker import chunk_transcript
from app.features.memory.schemas import EvidenceBundle
from app.features.memory.repo import (
    create_chunk,
    create_curation_action,
    create_source,
    set_chunk_excluded,
    set_memory_fact_status,
    update_chunk_embedding,
    update_chunk_fts,
)
from app.ports.embedding import EmbeddingProvider

logger = get_logger(__name__)


def _embedding_pairs(chunk_records: list[tuple[str, str]], embeddings) -> list:
    """Pair chunk ids with their embedding vectors, skipping empty vectors.

    Args:
        chunk_records: ``(chunk_id, content)`` tuples in the order their
            contents were sent to the embedding provider.
        embeddings: Provider output, expected to be index-aligned with
            ``chunk_records`` (one vector per input text).

    Returns:
        ``(chunk_id, vector)`` pairs for every aligned position whose vector
        is truthy (non-empty).
    """
    vectors = list(embeddings)
    if len(vectors) != len(chunk_records):
        # zip() below only covers the aligned prefix; surface the mismatch
        # instead of silently leaving trailing chunks without an embedding.
        logger.warning(
            "memory embedding count mismatch: chunks={} embeddings={}",
            len(chunk_records),
            len(vectors),
        )
    return [
        (chunk_id, vector)
        for (chunk_id, _), vector in zip(chunk_records, vectors)
        if vector
    ]


class MemoryService:
    """Async facade for memory ingest, retrieval, and curation actions."""

    def __init__(
        self,
        db: AsyncSession,
        *,
        embedding_provider: EmbeddingProvider | None = None,
    ):
        """
        Args:
            db: Async session used for all repo calls; this service commits it.
            embedding_provider: Optional provider; when None, ingest skips
                embeddings and retrieval gets no provider.
        """
        self._db = db
        self._embedding = embedding_provider

    async def ingest_transcript(
        self, user_id: str, conversation_id: str, transcript: str
    ) -> str:
        """Ingest a conversation transcript into memory.

        Creates a MemorySource, chunks the transcript, populates each chunk's
        FTS column and (when an embedding provider is configured) its
        embedding, runs optional LLM enrichment, then commits.

        Args:
            user_id: Owner of the new source and chunks.
            conversation_id: Conversation the transcript belongs to.
            transcript: Raw transcript text; surrounding whitespace is stripped.

        Returns:
            The id of the created memory source.

        Raises:
            ValueError: If ``transcript`` is empty or whitespace-only.
            Errors from the embedding provider propagate: unlike
            ``ingest_transcript_sync``, embedding is not best-effort here.
        """
        text = (transcript or "").strip()
        if not text:
            raise ValueError("transcript cannot be empty")

        source = await create_source(
            self._db,
            user_id=user_id,
            source_type="transcript",
            raw_text=text,
            conversation_id=conversation_id,
        )

        chunk_records: list[tuple[str, str]] = []
        for index, content in enumerate(chunk_transcript(text)):
            chunk = await create_chunk(
                self._db,
                source_id=source.id,
                user_id=user_id,
                content=content,
                chunk_index=index,
            )
            chunk_records.append((chunk.id, content))
        await self._db.flush()

        # FTS: populate content_tsv for every new chunk.
        for chunk_id, _ in chunk_records:
            await update_chunk_fts(self._db, chunk_id)

        # Embedding: only written when a provider was configured.
        if self._embedding and chunk_records:
            texts = [content for _, content in chunk_records]
            embeddings = await self._embedding.embed_texts(texts)
            for chunk_id, vector in _embedding_pairs(chunk_records, embeddings):
                await update_chunk_embedding(self._db, chunk_id, vector)

        # Best-effort enrichment: any failure (including import errors) is
        # logged and ingest still commits.
        # NOTE(review): if enrichment fails on a DB error, the session may need
        # a rollback and the commit below would then fail as well — consider a
        # savepoint (begin_nested) once it is confirmed enrichment does not
        # commit on its own.
        try:
            from app.core.config import settings
            from app.core.dependencies import get_llm_provider_fast
            from app.features.memory.enrichment import enrich_memory_after_ingest_async

            if settings.memory_enrichment_enabled:
                llm = get_llm_provider_fast().langchain_llm
                await enrich_memory_after_ingest_async(
                    self._db, user_id, source.id, llm
                )
        except Exception as e:
            logger.warning(
                "memory enrichment 跳过: {} exc_type={}", e, type(e).__name__
            )

        await self._db.commit()
        return source.id

    async def retrieve(
        self, user_id: str, query: str, *, top_k: int = 10
    ) -> EvidenceBundle:
        """Retrieve relevant evidence for ``query``.

        Delegates to HybridRetriever (constructed with this service's session
        and embedding provider) and validates its raw result into an
        EvidenceBundle.
        """
        # Imported lazily, as in the original module layout.
        from app.features.memory.retriever import HybridRetriever

        retriever = HybridRetriever(self._db, embedding_provider=self._embedding)
        raw = await retriever.retrieve(user_id=user_id, query=query, top_k=top_k)
        return EvidenceBundle.model_validate(raw)

    async def _record_curation(
        self,
        user_id: str,
        applied,
        *,
        action_type: str,
        target_type: str,
        target_id: str,
        reason: str = "",
    ) -> bool:
        """Record a curation action and commit, if the repo update applied.

        Args:
            user_id: Curating user.
            applied: Result of the repo status update; falsy means the target
                was not updated, so nothing is recorded.
            action_type: Curation action name stored with the record.
            target_type: Kind of target ("chunk" / "fact").
            target_id: Id of the curated target.
            reason: Optional reason; stored as ``{"reason": ...}`` when
                non-empty, otherwise details are None.

        Returns:
            True if the action was recorded and committed, else False.
        """
        if not applied:
            return False
        await create_curation_action(
            self._db,
            user_id=user_id,
            action_type=action_type,
            target_type=target_type,
            target_id=target_id,
            details={"reason": reason} if reason else None,
        )
        await self._db.commit()
        return True

    async def exclude_chunk(
        self, user_id: str, chunk_id: str, *, reason: str = ""
    ) -> bool:
        """Mark a chunk excluded and log an "exclude" action; False if not updated."""
        ok = await set_chunk_excluded(self._db, chunk_id, user_id, True)
        return await self._record_curation(
            user_id,
            ok,
            action_type="exclude",
            target_type="chunk",
            target_id=chunk_id,
            reason=reason,
        )

    async def restore_chunk(self, user_id: str, chunk_id: str) -> bool:
        """Clear a chunk's excluded flag and log a "restore" action; False if not updated."""
        ok = await set_chunk_excluded(self._db, chunk_id, user_id, False)
        return await self._record_curation(
            user_id,
            ok,
            action_type="restore",
            target_type="chunk",
            target_id=chunk_id,
        )

    async def confirm_fact(self, user_id: str, fact_id: str) -> bool:
        """Set a fact's status to "confirmed" and log a "confirm" action; False if not updated."""
        ok = await set_memory_fact_status(self._db, fact_id, user_id, "confirmed")
        return await self._record_curation(
            user_id,
            ok,
            action_type="confirm",
            target_type="fact",
            target_id=fact_id,
        )

    async def reject_fact(
        self, user_id: str, fact_id: str, *, reason: str = ""
    ) -> bool:
        """Set a fact's status to "rejected" and log a "reject" action; False if not updated."""
        ok = await set_memory_fact_status(self._db, fact_id, user_id, "rejected")
        return await self._record_curation(
            user_id,
            ok,
            action_type="reject",
            target_type="fact",
            target_id=fact_id,
            reason=reason,
        )


def ingest_transcript_sync(
    session,
    user_id: str,
    conversation_id: str,
    transcript: str,
) -> str:
    """Synchronous transcript ingest for Celery tasks.

    Creates the source, chunks, and FTS data; embeddings and LLM enrichment
    are best-effort (failures are logged and skipped). Commits the session.

    Args:
        session: Synchronous DB session used for all repo calls.
        user_id: Owner of the new source and chunks.
        conversation_id: Conversation the transcript belongs to.
        transcript: Raw transcript text; surrounding whitespace is stripped.

    Returns:
        The id of the created memory source.

    Raises:
        ValueError: If ``transcript`` is empty or whitespace-only.
    """
    from app.core.dependencies import get_embedding_provider
    from app.features.memory.repo import (
        create_chunk_sync,
        create_source_sync,
        update_chunk_embedding_sync,
        update_chunk_fts_sync,
    )

    text = (transcript or "").strip()
    if not text:
        raise ValueError("transcript cannot be empty")

    source = create_source_sync(
        session,
        user_id=user_id,
        source_type="transcript",
        raw_text=text,
        conversation_id=conversation_id,
    )
    session.flush()

    chunk_records: list[tuple[str, str]] = []
    for index, content in enumerate(chunk_transcript(text)):
        chunk = create_chunk_sync(
            session,
            source_id=source.id,
            user_id=user_id,
            content=content,
            chunk_index=index,
        )
        # Flush per chunk so chunk.id is available for the FTS update.
        session.flush()
        chunk_records.append((chunk.id, content))
        update_chunk_fts_sync(session, chunk.id)

    try:
        embedding_provider = get_embedding_provider()
        if chunk_records and embedding_provider is not None:
            texts = [content for _, content in chunk_records]
            # The provider is async-only; drive it with a fresh event loop.
            # This fails (and is logged below) if a loop is already running.
            embeddings = asyncio.run(embedding_provider.embed_texts(texts))
            for chunk_id, vector in _embedding_pairs(chunk_records, embeddings):
                update_chunk_embedding_sync(session, chunk_id, vector)
    except Exception as e:
        logger.warning(
            "memory embedding 跳过(sync): {} exc_type={}", e, type(e).__name__
        )

    try:
        from app.core.config import settings
        from app.features.memory.enrichment import enrich_memory_after_ingest_sync

        if settings.memory_enrichment_enabled:
            enrich_memory_after_ingest_sync(session, user_id, source.id, llm=None)
    except Exception as e:
        logger.warning(
            "memory enrichment 跳过(sync): {} exc_type={}", e, type(e).__name__
        )

    session.commit()
    return source.id