"""Memory embedding service boundary."""

from __future__ import annotations

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.features.memory.repo import (
    list_chunks_for_source,
    set_chunk_embedding_status,
    set_source_embedding_status,
    update_chunk_embedding,
)
from app.ports.embedding import EmbeddingProvider

logger = get_logger(__name__)


def _short_error(exc: BaseException | str, *, max_chars: int = 500) -> str:
    """Render ``exc`` as a short string suitable for a status ``error`` field.

    Args:
        exc: An exception or a ready-made message.
        max_chars: Upper bound on the length of the returned string.

    Returns:
        ``str(exc)`` truncated to at most ``max_chars`` characters. If the
        message of an exception is empty (e.g. a bare ``TimeoutError()``),
        the exception class name is used instead. A trailing ``"..."`` marks
        truncation when ``max_chars`` leaves room for it.
    """
    text = str(exc)
    if not text and isinstance(exc, BaseException):
        # A blank stored error is useless when debugging; record the type.
        text = type(exc).__name__
    if max_chars <= 0:
        return ""
    if len(text) <= max_chars:
        return text
    if max_chars <= 3:
        # No room for the ellipsis: slicing with a negative bound here would
        # return a string longer than max_chars.
        return text[:max_chars]
    return text[: max_chars - 3] + "..."


async def _commit_if_available(db: AsyncSession) -> None:
    """Await ``db.commit()`` if ``db`` has a ``commit`` attribute; else no-op."""
    commit = getattr(db, "commit", None)
    if commit is not None:
        await commit()


async def _rollback_if_available(db: AsyncSession) -> None:
    """Await ``db.rollback()`` if ``db`` has a ``rollback`` attribute; else no-op."""
    rollback = getattr(db, "rollback", None)
    if rollback is not None:
        await rollback()


class MemoryEmbeddingService:
    """Embeds persisted memory chunks and records source/chunk status."""

    def __init__(
        self,
        db: AsyncSession,
        *,
        embedding_provider: EmbeddingProvider | None = None,
    ) -> None:
        self._db = db
        self._embedding = embedding_provider

    async def embed_source(
        self,
        user_id: str,
        source_id: str,
        *,
        raise_on_failure: bool = False,
    ) -> dict:
        """Embed every chunk of one source and record source/chunk status.

        Chunks are loaded with ``include_excluded=True``. Their ``content`` is
        sent to the provider in a single ``embed_texts`` call, and the returned
        embeddings are matched to chunks by position. The source is marked
        ``"running"`` (and committed) before the provider is called.

        Args:
            user_id: Owner of the source.
            source_id: Source whose chunks are embedded.
            raise_on_failure: When true, raise instead of returning a
                ``"failed"`` result.

        Returns:
            - ``{"status": "skipped", "reason": "no_chunks", "chunks": 0}``
              when the source has no chunks.
            - ``{"status": "failed", "error": ..., "chunks": n}`` when no
              provider is configured or the provider raises.
            - Otherwise a dict with ``status`` (``"success"``, ``"partial"``
              or ``"failed"``), ``chunks``, ``vectors_written`` and
              ``failed_chunks`` (list of chunk ids).

        Raises:
            RuntimeError: With ``raise_on_failure``, if no provider is
                configured or no chunk received a vector.
            Exception: With ``raise_on_failure``, whatever ``embed_texts``
                raised. Independently of ``raise_on_failure``, any error
                raised while persisting vectors or the final status is
                re-raised, after the session is rolled back and the source
                and its chunks are marked failed.
        """
        chunks = await list_chunks_for_source(
            self._db,
            user_id=user_id,
            source_id=source_id,
            include_excluded=True,
        )
        if not chunks:
            await set_source_embedding_status(
                self._db,
                source_id=source_id,
                user_id=user_id,
                status="skipped",
                error="no_chunks",
            )
            await _commit_if_available(self._db)
            return {"status": "skipped", "reason": "no_chunks", "chunks": 0}

        # Capture plain ids once, so the failure paths never need to read chunk
        # attributes again (if chunks are ORM objects, a rollback may expire them).
        chunk_ids = [c.id for c in chunks]

        if self._embedding is None:
            err = "embedding_provider_missing"
            await self._mark_failed(user_id, source_id, chunk_ids, err)
            logger.warning(
                "event=memory_embedding_provider_missing user_id={} source_id={} chunks={}",
                user_id,
                source_id,
                len(chunks),
            )
            if raise_on_failure:
                raise RuntimeError(err)
            return {"status": "failed", "error": err, "chunks": len(chunks)}

        await set_source_embedding_status(
            self._db,
            source_id=source_id,
            user_id=user_id,
            status="running",
            error=None,
        )
        await _commit_if_available(self._db)

        try:
            texts = [c.content for c in chunks]
            raw_embeddings = await self._embedding.embed_texts(texts)
            embeddings = list(raw_embeddings or [])
        except Exception as e:
            err = _short_error(e)
            await self._mark_failed(user_id, source_id, chunk_ids, err)
            logger.warning(
                "event=memory_embedding_failed user_id={} source_id={} chunks={} exc_type={} exc={}",
                user_id,
                source_id,
                len(chunks),
                type(e).__name__,
                err,
            )
            if raise_on_failure:
                raise
            return {"status": "failed", "error": err, "chunks": len(chunks)}

        try:
            vectors_written, failed_chunk_ids = await self._persist_vectors(
                user_id, source_id, chunks, embeddings
            )
            status = "success"
            error: str | None = None
            if failed_chunk_ids:
                status = "partial" if vectors_written else "failed"
                error = f"failed_chunks={len(failed_chunk_ids)}"
            await set_source_embedding_status(
                self._db,
                source_id=source_id,
                user_id=user_id,
                status=status,
                error=error,
            )
            await _commit_if_available(self._db)
        except Exception as e:
            # Without this, the source (already committed as "running") would
            # stay "running" forever, with a half-applied batch of writes pending.
            await self._recover_from_persist_error(user_id, source_id, chunk_ids, e)
            raise

        # Raised outside the try above, so it is not treated as a persistence error.
        if status == "failed" and raise_on_failure:
            raise RuntimeError(error or "embedding_failed")
        return {
            "status": status,
            "chunks": len(chunks),
            "vectors_written": vectors_written,
            "failed_chunks": failed_chunk_ids,
        }

    async def _persist_vectors(
        self,
        user_id: str,
        source_id: str,
        chunks: list,
        embeddings: list,
    ) -> tuple[int, list[str]]:
        """Write each positional embedding to its chunk; return ``(written, failed_ids)``.

        Does not commit. A falsy embedding marks its chunk failed with
        ``"empty_embedding"``. Chunks beyond the end of ``embeddings`` are
        marked failed with ``"embedding_count_mismatch"``. Any length mismatch
        is logged, including surplus embeddings, which are ignored.
        """
        vectors_written = 0
        failed_chunk_ids: list[str] = []
        # zip stops at the shorter side; the length check below handles the rest.
        for chunk, emb in zip(chunks, embeddings, strict=False):
            if emb:
                vectors_written += 1
                await update_chunk_embedding(self._db, chunk.id, emb)
            else:
                failed_chunk_ids.append(chunk.id)
                await set_chunk_embedding_status(
                    self._db,
                    chunk.id,
                    status="failed",
                    error="empty_embedding",
                )

        if len(embeddings) != len(chunks):
            # Empty when the provider returned more embeddings than chunks.
            missing = chunks[len(embeddings) :]
            failed_chunk_ids.extend(c.id for c in missing)
            for chunk in missing:
                await set_chunk_embedding_status(
                    self._db,
                    chunk.id,
                    status="failed",
                    error="embedding_count_mismatch",
                )
            logger.warning(
                "event=memory_embedding_count_mismatch user_id={} source_id={} chunks={} embeddings={}",
                user_id,
                source_id,
                len(chunks),
                len(embeddings),
            )
        return vectors_written, failed_chunk_ids

    async def _recover_from_persist_error(
        self,
        user_id: str,
        source_id: str,
        chunk_ids: list[str],
        exc: Exception,
    ) -> None:
        """Best-effort cleanup after a persistence error: roll back, then mark failed.

        Ordinary exceptions raised during cleanup are logged, not propagated,
        so the caller can re-raise the original persistence error.
        """
        err = _short_error(exc)
        logger.warning(
            "event=memory_embedding_persist_failed user_id={} source_id={} chunks={} exc_type={} exc={}",
            user_id,
            source_id,
            len(chunk_ids),
            type(exc).__name__,
            err,
        )
        try:
            # Discard the uncommitted vector writes, so no chunk keeps a vector
            # while being marked failed. This also resets a session left
            # unusable by a failed flush.
            await _rollback_if_available(self._db)
            await self._mark_failed(user_id, source_id, chunk_ids, err)
        except Exception as cleanup_exc:
            logger.warning(
                "event=memory_embedding_mark_failed_error user_id={} source_id={} exc_type={} exc={}",
                user_id,
                source_id,
                type(cleanup_exc).__name__,
                _short_error(cleanup_exc),
            )

    async def _mark_failed(
        self,
        user_id: str,
        source_id: str,
        chunk_ids: list[str],
        error: str,
    ) -> None:
        """Mark the source and every chunk in ``chunk_ids`` as failed with ``error``, then commit."""
        await set_source_embedding_status(
            self._db,
            source_id=source_id,
            user_id=user_id,
            status="failed",
            error=error,
        )
        for chunk_id in chunk_ids:
            await set_chunk_embedding_status(
                self._db,
                chunk_id,
                status="failed",
                error=error,
            )
        await _commit_if_available(self._db)


__all__ = ["MemoryEmbeddingService"]