Simplify AI memory pipeline

This commit is contained in:
Kevin
2026-04-30 16:22:55 +08:00
parent 7617ea902c
commit 3234396254
35 changed files with 1002 additions and 579 deletions

View File

@@ -10,6 +10,11 @@ from app.features.conversation.lineage_schemas import (
primary_user_message_id_from_lineage,
)
from app.features.memory.chunker import chunk_transcript
from app.features.memory.embedding_scheduler import (
MemoryEmbeddingRequest,
MemoryEmbeddingScheduler,
)
from app.features.memory.embedding_service import MemoryEmbeddingService
from app.features.memory.enrichment_scheduler import (
MemoryEnrichmentRequest,
MemoryEnrichmentScheduler,
@@ -17,7 +22,6 @@ from app.features.memory.enrichment_scheduler import (
from app.features.memory.repo import (
create_chunk,
create_source,
update_chunk_embedding,
)
from app.ports.embedding import EmbeddingProvider
@@ -32,10 +36,12 @@ class MemoryIngestService:
db: AsyncSession,
*,
embedding_provider: EmbeddingProvider | None = None,
embedding_scheduler: MemoryEmbeddingScheduler | None = None,
enrichment_scheduler: MemoryEnrichmentScheduler | None = None,
) -> None:
self._db = db
self._embedding = embedding_provider
self._embedding_scheduler = embedding_scheduler or MemoryEmbeddingScheduler()
self._enrichment_scheduler = enrichment_scheduler or MemoryEnrichmentScheduler()
async def ingest_transcript(
@@ -74,19 +80,17 @@ class MemoryIngestService:
chunk_records.append((chunk.id, content))
await self._db.flush()
vectors_written = 0
if self._embedding and chunk_records:
texts = [content for _, content in chunk_records]
embeddings = await self._embedding.embed_texts(texts)
for (chunk_id, _), emb in zip(
chunk_records, embeddings, strict=False
):
if emb:
vectors_written += 1
await update_chunk_embedding(self._db, chunk_id, emb)
await self._db.commit()
embedding_result = await MemoryEmbeddingService(
self._db,
embedding_provider=self._embedding,
).embed_source(user_id, source.id)
embedding_task_id = self._schedule_embedding_retry_if_needed(
user_id,
source.id,
embedding_result,
)
emb_ok = self._embedding.is_available() if self._embedding else False
enrichment_task_id = self._enrichment_scheduler.schedule(
MemoryEnrichmentRequest(user_id=user_id, source_id=source.id)
@@ -94,13 +98,16 @@ class MemoryIngestService:
logger.info(
"event=memory_ingest_done user_id={} conversation_id={} source_id={} "
"chunks={} vectors_written={} embedding_available={} enrichment_enabled={} enrichment_task_id={}",
"chunks={} vectors_written={} embedding_status={} embedding_available={} "
"embedding_task_id={} enrichment_enabled={} enrichment_task_id={}",
user_id,
conversation_id,
source.id,
len(chunk_records),
vectors_written,
embedding_result.get("vectors_written", 0),
embedding_result.get("status"),
emb_ok,
embedding_task_id,
settings.memory_enrichment_enabled,
enrichment_task_id,
)
@@ -152,17 +159,29 @@ class MemoryIngestService:
chunk_records.append((chunk.id, content))
await self._db.flush()
await self._db.commit()
vectors_written = 0
if self._embedding and chunk_records:
texts = [content for _, content in chunk_records]
embeddings = await self._embedding.embed_texts(texts)
for (chunk_id, _), emb in zip(chunk_records, embeddings, strict=False):
if emb:
vectors_written += 1
await update_chunk_embedding(self._db, chunk_id, emb)
embedding_retry_task_ids: list[str] = []
embedding_statuses: dict[str, int] = {}
embedding_service = MemoryEmbeddingService(
self._db,
embedding_provider=self._embedding,
)
for source_id in source_ids:
result = await embedding_service.embed_source(user_id, source_id)
vectors_written += int(result.get("vectors_written") or 0)
status = str(result.get("status") or "unknown")
embedding_statuses[status] = embedding_statuses.get(status, 0) + 1
task_id = self._schedule_embedding_retry_if_needed(
user_id,
source_id,
result,
memoir_correlation_id=memoir_correlation_id,
)
if task_id:
embedding_retry_task_ids.append(task_id)
await self._db.commit()
emb_ok = self._embedding.is_available() if self._embedding else False
task_ids = self._enrichment_scheduler.schedule_many(
user_id,
@@ -172,16 +191,38 @@ class MemoryIngestService:
logger.info(
"event=memory_ingest_batch_done user_id={} sources={} chunks={} "
"vectors_written={} embedding_available={} enrichment_enabled={} enrichment_tasks={}",
"vectors_written={} embedding_available={} embedding_statuses={} "
"embedding_retry_tasks={} enrichment_enabled={} enrichment_tasks={}",
user_id,
len(source_ids),
len(chunk_records),
vectors_written,
emb_ok,
embedding_statuses,
len(embedding_retry_task_ids),
settings.memory_enrichment_enabled,
len(task_ids),
)
return source_ids
def _schedule_embedding_retry_if_needed(
self,
user_id: str,
source_id: str,
embedding_result: dict,
*,
memoir_correlation_id: str | None = None,
) -> str | None:
status = str(embedding_result.get("status") or "")
if status not in {"failed", "partial"}:
return None
return self._embedding_scheduler.schedule(
MemoryEmbeddingRequest(
user_id=user_id,
source_id=source_id,
memoir_correlation_id=memoir_correlation_id,
)
)
# Public API: the only name exported by a wildcard import of this module.
__all__ = ["MemoryIngestService"]