"""Memory ingest service boundary.""" from __future__ import annotations from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.logging import get_logger from app.features.conversation.lineage_schemas import ( primary_user_message_id_from_lineage, ) from app.features.memory.chunker import chunk_transcript from app.features.memory.embedding_scheduler import ( MemoryEmbeddingRequest, MemoryEmbeddingScheduler, ) from app.features.memory.embedding_service import MemoryEmbeddingService from app.features.memory.enrichment_scheduler import ( MemoryEnrichmentRequest, MemoryEnrichmentScheduler, ) from app.features.memory.repo import ( create_chunk, create_source, ) from app.ports.embedding import EmbeddingProvider logger = get_logger(__name__) class MemoryIngestService: """Creates memory sources/chunks and schedules post-commit enrichment.""" def __init__( self, db: AsyncSession, *, embedding_provider: EmbeddingProvider | None = None, embedding_scheduler: MemoryEmbeddingScheduler | None = None, enrichment_scheduler: MemoryEnrichmentScheduler | None = None, ) -> None: self._db = db self._embedding = embedding_provider self._embedding_scheduler = embedding_scheduler or MemoryEmbeddingScheduler() self._enrichment_scheduler = enrichment_scheduler or MemoryEnrichmentScheduler() async def ingest_transcript( self, user_id: str, conversation_id: str, transcript: str, *, lineage_json: dict | None = None, ) -> str: if not transcript or not transcript.strip(): raise ValueError("transcript cannot be empty") primary_mid = ( primary_user_message_id_from_lineage(lineage_json) if lineage_json else None ) source = await create_source( self._db, user_id=user_id, source_type="transcript", raw_text=transcript.strip(), conversation_id=conversation_id, lineage_json=lineage_json, primary_user_message_id=primary_mid, ) chunk_records: list[tuple[str, str]] = [] for i, content in enumerate(chunk_transcript(transcript.strip())): chunk = await create_chunk( self._db, source_id=source.id, user_id=user_id, content=content, chunk_index=i, ) chunk_records.append((chunk.id, content)) await self._db.flush() await self._db.commit() embedding_result = await MemoryEmbeddingService( self._db, embedding_provider=self._embedding, ).embed_source(user_id, source.id) embedding_task_id = self._schedule_embedding_retry_if_needed( user_id, source.id, embedding_result, ) emb_ok = self._embedding.is_available() if self._embedding else False enrichment_task_id = self._enrichment_scheduler.schedule( MemoryEnrichmentRequest(user_id=user_id, source_id=source.id) ) logger.info( "event=memory_ingest_done user_id={} conversation_id={} source_id={} " "chunks={} vectors_written={} embedding_status={} embedding_available={} " "embedding_task_id={} enrichment_enabled={} enrichment_task_id={}", user_id, conversation_id, source.id, len(chunk_records), embedding_result.get("vectors_written", 0), embedding_result.get("status"), emb_ok, embedding_task_id, settings.memory_enrichment_enabled, enrichment_task_id, ) return source.id async def ingest_transcripts_batch( self, user_id: str, items: list[tuple[str, str, dict | None]], *, memoir_correlation_id: str | None = None, ) -> list[str]: """ Batch ingest transcript items through the async memory path. items: (conversation_id, transcript, lineage_json). Empty transcripts are skipped. """ source_ids: list[str] = [] chunk_records: list[tuple[str, str]] = [] for conversation_id, transcript, lineage_json in items: text = (transcript or "").strip() if not text: continue primary_mid = ( primary_user_message_id_from_lineage(lineage_json) if lineage_json else None ) source = await create_source( self._db, user_id=user_id, source_type="transcript", raw_text=text, conversation_id=conversation_id or None, lineage_json=lineage_json, primary_user_message_id=primary_mid, ) source_ids.append(source.id) for i, content in enumerate(chunk_transcript(text)): chunk = await create_chunk( self._db, source_id=source.id, user_id=user_id, content=content, chunk_index=i, ) chunk_records.append((chunk.id, content)) await self._db.flush() await self._db.commit() vectors_written = 0 embedding_retry_task_ids: list[str] = [] embedding_statuses: dict[str, int] = {} embedding_service = MemoryEmbeddingService( self._db, embedding_provider=self._embedding, ) for source_id in source_ids: result = await embedding_service.embed_source(user_id, source_id) vectors_written += int(result.get("vectors_written") or 0) status = str(result.get("status") or "unknown") embedding_statuses[status] = embedding_statuses.get(status, 0) + 1 task_id = self._schedule_embedding_retry_if_needed( user_id, source_id, result, memoir_correlation_id=memoir_correlation_id, ) if task_id: embedding_retry_task_ids.append(task_id) emb_ok = self._embedding.is_available() if self._embedding else False task_ids = self._enrichment_scheduler.schedule_many( user_id, source_ids, memoir_correlation_id=memoir_correlation_id, ) logger.info( "event=memory_ingest_batch_done user_id={} sources={} chunks={} " "vectors_written={} embedding_available={} embedding_statuses={} " "embedding_retry_tasks={} enrichment_enabled={} enrichment_tasks={}", user_id, len(source_ids), len(chunk_records), vectors_written, emb_ok, embedding_statuses, len(embedding_retry_task_ids), settings.memory_enrichment_enabled, len(task_ids), ) return source_ids def _schedule_embedding_retry_if_needed( self, user_id: str, source_id: str, embedding_result: dict, *, memoir_correlation_id: str | None = None, ) -> str | None: status = str(embedding_result.get("status") or "") if status not in {"failed", "partial"}: return None return self._embedding_scheduler.schedule( MemoryEmbeddingRequest( user_id=user_id, source_id=source_id, memoir_correlation_id=memoir_correlation_id, ) ) __all__ = ["MemoryIngestService"]