Simplify AI memory pipeline
This commit is contained in:
@@ -10,6 +10,11 @@ from app.features.conversation.lineage_schemas import (
|
||||
primary_user_message_id_from_lineage,
|
||||
)
|
||||
from app.features.memory.chunker import chunk_transcript
|
||||
from app.features.memory.embedding_scheduler import (
|
||||
MemoryEmbeddingRequest,
|
||||
MemoryEmbeddingScheduler,
|
||||
)
|
||||
from app.features.memory.embedding_service import MemoryEmbeddingService
|
||||
from app.features.memory.enrichment_scheduler import (
|
||||
MemoryEnrichmentRequest,
|
||||
MemoryEnrichmentScheduler,
|
||||
@@ -17,7 +22,6 @@ from app.features.memory.enrichment_scheduler import (
|
||||
from app.features.memory.repo import (
|
||||
create_chunk,
|
||||
create_source,
|
||||
update_chunk_embedding,
|
||||
)
|
||||
from app.ports.embedding import EmbeddingProvider
|
||||
|
||||
@@ -32,10 +36,12 @@ class MemoryIngestService:
|
||||
db: AsyncSession,
|
||||
*,
|
||||
embedding_provider: EmbeddingProvider | None = None,
|
||||
embedding_scheduler: MemoryEmbeddingScheduler | None = None,
|
||||
enrichment_scheduler: MemoryEnrichmentScheduler | None = None,
|
||||
) -> None:
|
||||
self._db = db
|
||||
self._embedding = embedding_provider
|
||||
self._embedding_scheduler = embedding_scheduler or MemoryEmbeddingScheduler()
|
||||
self._enrichment_scheduler = enrichment_scheduler or MemoryEnrichmentScheduler()
|
||||
|
||||
async def ingest_transcript(
|
||||
@@ -74,19 +80,17 @@ class MemoryIngestService:
|
||||
chunk_records.append((chunk.id, content))
|
||||
|
||||
await self._db.flush()
|
||||
|
||||
vectors_written = 0
|
||||
if self._embedding and chunk_records:
|
||||
texts = [content for _, content in chunk_records]
|
||||
embeddings = await self._embedding.embed_texts(texts)
|
||||
for (chunk_id, _), emb in zip(
|
||||
chunk_records, embeddings, strict=False
|
||||
):
|
||||
if emb:
|
||||
vectors_written += 1
|
||||
await update_chunk_embedding(self._db, chunk_id, emb)
|
||||
|
||||
await self._db.commit()
|
||||
|
||||
embedding_result = await MemoryEmbeddingService(
|
||||
self._db,
|
||||
embedding_provider=self._embedding,
|
||||
).embed_source(user_id, source.id)
|
||||
embedding_task_id = self._schedule_embedding_retry_if_needed(
|
||||
user_id,
|
||||
source.id,
|
||||
embedding_result,
|
||||
)
|
||||
emb_ok = self._embedding.is_available() if self._embedding else False
|
||||
enrichment_task_id = self._enrichment_scheduler.schedule(
|
||||
MemoryEnrichmentRequest(user_id=user_id, source_id=source.id)
|
||||
@@ -94,13 +98,16 @@ class MemoryIngestService:
|
||||
|
||||
logger.info(
|
||||
"event=memory_ingest_done user_id={} conversation_id={} source_id={} "
|
||||
"chunks={} vectors_written={} embedding_available={} enrichment_enabled={} enrichment_task_id={}",
|
||||
"chunks={} vectors_written={} embedding_status={} embedding_available={} "
|
||||
"embedding_task_id={} enrichment_enabled={} enrichment_task_id={}",
|
||||
user_id,
|
||||
conversation_id,
|
||||
source.id,
|
||||
len(chunk_records),
|
||||
vectors_written,
|
||||
embedding_result.get("vectors_written", 0),
|
||||
embedding_result.get("status"),
|
||||
emb_ok,
|
||||
embedding_task_id,
|
||||
settings.memory_enrichment_enabled,
|
||||
enrichment_task_id,
|
||||
)
|
||||
@@ -152,17 +159,29 @@ class MemoryIngestService:
|
||||
chunk_records.append((chunk.id, content))
|
||||
|
||||
await self._db.flush()
|
||||
await self._db.commit()
|
||||
|
||||
vectors_written = 0
|
||||
if self._embedding and chunk_records:
|
||||
texts = [content for _, content in chunk_records]
|
||||
embeddings = await self._embedding.embed_texts(texts)
|
||||
for (chunk_id, _), emb in zip(chunk_records, embeddings, strict=False):
|
||||
if emb:
|
||||
vectors_written += 1
|
||||
await update_chunk_embedding(self._db, chunk_id, emb)
|
||||
embedding_retry_task_ids: list[str] = []
|
||||
embedding_statuses: dict[str, int] = {}
|
||||
embedding_service = MemoryEmbeddingService(
|
||||
self._db,
|
||||
embedding_provider=self._embedding,
|
||||
)
|
||||
for source_id in source_ids:
|
||||
result = await embedding_service.embed_source(user_id, source_id)
|
||||
vectors_written += int(result.get("vectors_written") or 0)
|
||||
status = str(result.get("status") or "unknown")
|
||||
embedding_statuses[status] = embedding_statuses.get(status, 0) + 1
|
||||
task_id = self._schedule_embedding_retry_if_needed(
|
||||
user_id,
|
||||
source_id,
|
||||
result,
|
||||
memoir_correlation_id=memoir_correlation_id,
|
||||
)
|
||||
if task_id:
|
||||
embedding_retry_task_ids.append(task_id)
|
||||
|
||||
await self._db.commit()
|
||||
emb_ok = self._embedding.is_available() if self._embedding else False
|
||||
task_ids = self._enrichment_scheduler.schedule_many(
|
||||
user_id,
|
||||
@@ -172,16 +191,38 @@ class MemoryIngestService:
|
||||
|
||||
logger.info(
|
||||
"event=memory_ingest_batch_done user_id={} sources={} chunks={} "
|
||||
"vectors_written={} embedding_available={} enrichment_enabled={} enrichment_tasks={}",
|
||||
"vectors_written={} embedding_available={} embedding_statuses={} "
|
||||
"embedding_retry_tasks={} enrichment_enabled={} enrichment_tasks={}",
|
||||
user_id,
|
||||
len(source_ids),
|
||||
len(chunk_records),
|
||||
vectors_written,
|
||||
emb_ok,
|
||||
embedding_statuses,
|
||||
len(embedding_retry_task_ids),
|
||||
settings.memory_enrichment_enabled,
|
||||
len(task_ids),
|
||||
)
|
||||
return source_ids
|
||||
|
||||
def _schedule_embedding_retry_if_needed(
|
||||
self,
|
||||
user_id: str,
|
||||
source_id: str,
|
||||
embedding_result: dict,
|
||||
*,
|
||||
memoir_correlation_id: str | None = None,
|
||||
) -> str | None:
|
||||
status = str(embedding_result.get("status") or "")
|
||||
if status not in {"failed", "partial"}:
|
||||
return None
|
||||
return self._embedding_scheduler.schedule(
|
||||
MemoryEmbeddingRequest(
|
||||
user_id=user_id,
|
||||
source_id=source_id,
|
||||
memoir_correlation_id=memoir_correlation_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["MemoryIngestService"]
|
||||
|
||||
Reference in New Issue
Block a user