183 lines
5.5 KiB
Python
183 lines
5.5 KiB
Python
"""Memory embedding service boundary."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.features.memory.repo import (
|
|
list_chunks_for_source,
|
|
set_chunk_embedding_status,
|
|
set_source_embedding_status,
|
|
update_chunk_embedding,
|
|
)
|
|
from app.ports.embedding import EmbeddingProvider
|
|
|
|
# Module-level logger from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
def _short_error(exc: BaseException | str, *, max_chars: int = 500) -> str:
|
|
text = str(exc)
|
|
if len(text) > max_chars:
|
|
return text[: max_chars - 3] + "..."
|
|
return text
|
|
|
|
|
|
async def _commit_if_available(db: AsyncSession) -> None:
    """Await ``db.commit()`` when ``db`` has a non-``None`` ``commit`` attribute.

    Objects without such an attribute are left untouched, so callers can
    pass a session stand-in that has no ``commit``.
    """
    committer = getattr(db, "commit", None)
    if committer is None:
        return
    await committer()
|
|
|
|
|
|
class MemoryEmbeddingService:
    """Embeds persisted memory chunks and records source/chunk status.

    The service loads a source's chunks through the memory repo, asks the
    configured embedding provider for one vector per chunk, stores the
    vectors, and keeps the source-level and chunk-level embedding status in
    step with the outcome.
    """

    def __init__(
        self,
        db: AsyncSession,
        *,
        embedding_provider: EmbeddingProvider | None = None,
    ) -> None:
        """Bind the service to a session and an optional embedding provider.

        Args:
            db: Session handed to every repo call and committed after each
                status transition.
            embedding_provider: Object exposing an awaitable
                ``embed_texts(texts)``. When ``None``, ``embed_source``
                records a failure instead of embedding.
        """
        self._db = db
        self._embedding = embedding_provider

    async def embed_source(
        self,
        user_id: str,
        source_id: str,
        *,
        raise_on_failure: bool = False,
    ) -> dict:
        """Embed every chunk of one source and record the outcome.

        Args:
            user_id: Owner passed through to the repo calls.
            source_id: Source whose chunks are embedded.
            raise_on_failure: When true, a "failed" outcome is raised instead
                of being returned.

        Returns:
            A dict whose ``status`` is one of:

            * ``"skipped"`` - no chunks; carries ``reason`` and ``chunks``.
            * ``"failed"`` - always carries ``error`` and ``chunks``; when the
              failure was detected per chunk it also carries
              ``vectors_written`` and ``failed_chunks``.
            * ``"partial"`` / ``"success"`` - carries ``chunks``,
              ``vectors_written`` and ``failed_chunks``.

        Raises:
            RuntimeError: With ``raise_on_failure`` set, when no provider is
                configured or no vector could be stored for any chunk.
            Exception: With ``raise_on_failure`` set, whatever the provider
                raised. Independently of that flag, any error raised while
                storing vectors or statuses is re-raised after the source has
                been marked as failed.
        """
        chunks = await list_chunks_for_source(
            self._db,
            user_id=user_id,
            source_id=source_id,
            # NOTE(review): presumably so excluded chunks get vectors too - confirm against the repo.
            include_excluded=True,
        )
        if not chunks:
            await set_source_embedding_status(
                self._db,
                source_id=source_id,
                user_id=user_id,
                status="skipped",
                error="no_chunks",
            )
            await _commit_if_available(self._db)
            return {"status": "skipped", "reason": "no_chunks", "chunks": 0}

        # Captured once, up front: the failure paths below must not read
        # attributes off chunk objects after the session has been rolled back.
        chunk_ids = [c.id for c in chunks]

        if self._embedding is None:
            err = "embedding_provider_missing"
            await self._mark_failed(user_id, source_id, chunk_ids, err)
            if raise_on_failure:
                raise RuntimeError(err)
            return {"status": "failed", "error": err, "chunks": len(chunks)}

        await set_source_embedding_status(
            self._db,
            source_id=source_id,
            user_id=user_id,
            status="running",
            error=None,
        )
        await _commit_if_available(self._db)

        try:
            texts = [c.content for c in chunks]
            raw_embeddings = await self._embedding.embed_texts(texts)
            embeddings = list(raw_embeddings or [])
        except Exception as e:
            err = _short_error(e)
            await self._mark_failed(user_id, source_id, chunk_ids, err)
            logger.warning(
                "event=memory_embedding_failed user_id={} source_id={} chunks={} exc_type={} exc={}",
                user_id,
                source_id,
                len(chunks),
                type(e).__name__,
                err,
            )
            if raise_on_failure:
                raise
            return {"status": "failed", "error": err, "chunks": len(chunks)}

        try:
            vectors_written, failed_chunk_ids = await self._persist_embeddings(
                user_id, source_id, chunks, embeddings
            )

            status = "success"
            error = None
            if failed_chunk_ids:
                status = "partial" if vectors_written else "failed"
                error = f"failed_chunks={len(failed_chunk_ids)}"
            await set_source_embedding_status(
                self._db,
                source_id=source_id,
                user_id=user_id,
                status=status,
                error=error,
            )
            await _commit_if_available(self._db)
        except Exception as e:
            # "running" is already committed, so without this the source
            # would stay in that state forever. The original error still
            # reaches the caller.
            await self._abort_persist(user_id, source_id, chunk_ids, e)
            raise

        if status == "failed" and raise_on_failure:
            raise RuntimeError(error or "embedding_failed")

        result = {
            "status": status,
            "chunks": len(chunks),
            "vectors_written": vectors_written,
            "failed_chunks": failed_chunk_ids,
        }
        if status == "failed":
            # Same key the other "failed" results carry, so callers can read
            # result["error"] whenever status is "failed".
            result["error"] = error
        return result

    async def _persist_embeddings(
        self,
        user_id: str,
        source_id: str,
        chunks: list,
        embeddings: list,
    ) -> tuple[int, list[str]]:
        """Store one vector per chunk and flag chunks that did not get one.

        Chunks are paired with embeddings by position. A falsy embedding
        marks its chunk failed with ``empty_embedding``; chunks beyond the end
        of ``embeddings`` are marked failed with ``embedding_count_mismatch``.
        Nothing is committed here.

        Returns:
            ``(vectors_written, failed_chunk_ids)``.
        """
        vectors_written = 0
        failed_chunk_ids: list[str] = []

        # strict=False on purpose: a length mismatch is handled below, not raised.
        for chunk, emb in zip(chunks, embeddings, strict=False):
            if emb:
                await update_chunk_embedding(self._db, chunk.id, emb)
                vectors_written += 1
            else:
                failed_chunk_ids.append(chunk.id)
                await set_chunk_embedding_status(
                    self._db,
                    chunk.id,
                    status="failed",
                    error="empty_embedding",
                )

        if len(embeddings) != len(chunks):
            # Empty when the provider returned too many vectors; then only the
            # warning is emitted.
            missing = chunks[len(embeddings) :]
            for chunk in missing:
                failed_chunk_ids.append(chunk.id)
                await set_chunk_embedding_status(
                    self._db,
                    chunk.id,
                    status="failed",
                    error="embedding_count_mismatch",
                )
            logger.warning(
                "event=memory_embedding_count_mismatch user_id={} source_id={} chunks={} embeddings={}",
                user_id,
                source_id,
                len(chunks),
                len(embeddings),
            )

        return vectors_written, failed_chunk_ids

    async def _abort_persist(
        self,
        user_id: str,
        source_id: str,
        chunk_ids: list[str],
        exc: BaseException,
    ) -> None:
        """Best-effort: roll back pending writes and record the source as failed.

        Never raises; a failure of the cleanup itself is only logged so the
        caller can re-raise the original error.
        """
        err = _short_error(exc)
        logger.warning(
            "event=memory_embedding_persist_failed user_id={} source_id={} chunks={} exc_type={} exc={}",
            user_id,
            source_id,
            len(chunk_ids),
            type(exc).__name__,
            err,
        )
        try:
            # Everything pending at this point was written by embed_source
            # after its last commit, so rolling back discards only that.
            rollback = getattr(self._db, "rollback", None)
            if rollback is not None:
                await rollback()
            await self._mark_failed(user_id, source_id, chunk_ids, err)
        except Exception as cleanup_exc:
            logger.warning(
                "event=memory_embedding_status_cleanup_failed user_id={} source_id={} exc_type={} exc={}",
                user_id,
                source_id,
                type(cleanup_exc).__name__,
                _short_error(cleanup_exc),
            )

    async def _mark_failed(
        self,
        user_id: str,
        source_id: str,
        chunk_ids: list[str],
        error: str,
    ) -> None:
        """Set the source and each listed chunk to "failed" with ``error``, then commit."""
        await set_source_embedding_status(
            self._db,
            source_id=source_id,
            user_id=user_id,
            status="failed",
            error=error,
        )
        for chunk_id in chunk_ids:
            await set_chunk_embedding_status(
                self._db,
                chunk_id,
                status="failed",
                error=error,
            )
        await _commit_if_available(self._db)
|
|
|
|
|
|
# Only the service class is exported; the underscore-prefixed helpers stay module-private.
__all__ = ["MemoryEmbeddingService"]
|