"""
Memory pipeline Celery tasks — retry embedding and enrichment after durable ingest.

Tasks are routed to ``celery_defaults.memory_enrichment_queue`` (default
``memory_idle``); run workers with ``-Q celery,memory_idle`` or a dedicated
low-priority worker for that queue.
"""

import asyncio
import time
from typing import Any, cast

from celery import shared_task

from app.core.business_telemetry import business_span
from app.core.config import settings
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_embedding_provider
from app.core.logging import get_logger
from app.core.memoir_pipeline_progress import merge_fanout_item
from app.features.memory.service import MemoryService
from app.core.runtime_constants import celery_defaults
from app.features.memory.constants import memory

logger = get_logger(__name__)


async def _enrich_memory_source_async(
    user_id: str,
    source_id: str,
) -> None:
    """Run ``MemoryService.enrich_source`` for one source inside a fresh DB session."""
    async with AsyncSessionLocal() as db:
        service = MemoryService(db)
        # NOTE(review): llm=None presumably lets the service pick its own
        # default client — confirm against MemoryService.enrich_source.
        await service.enrich_source(user_id, source_id, llm=None)


async def _embed_memory_source_async(
    user_id: str,
    source_id: str,
) -> dict:
    """Run ``MemoryService.embed_source`` for one source and return its result.

    ``raise_on_failure=True`` is passed so that a failed embedding surfaces as
    an exception, which the calling Celery task turns into a retry.
    """
    async with AsyncSessionLocal() as db:
        service = MemoryService(db, embedding_provider=get_embedding_provider())
        result = await service.embed_source(
            user_id,
            source_id,
            raise_on_failure=True,
        )
        return result


def _normalize_correlation_id(memoir_correlation_id: str | None) -> str | None:
    """Return the stripped correlation id, or ``None`` when it is missing/blank."""
    cid = (memoir_correlation_id or "").strip()
    return cid or None


def _record_fanout(
    memoir_correlation_id: str | None,
    *,
    list_name: str,
    source_id: str,
    task_id: str,
    status: str,
    extra: dict[str, Any] | None = None,
) -> None:
    """Best-effort update of one fan-out progress item.

    Does nothing when there is no usable correlation id. Any error raised
    while recording progress is logged and swallowed: progress bookkeeping
    must never mask a successful enqueue, re-run work that already succeeded,
    or prevent a failed task from scheduling its retry.
    """
    try:
        cid = _normalize_correlation_id(memoir_correlation_id)
        if cid is None:
            return
        fields: dict[str, Any] = {
            "list_name": list_name,
            "id_field": "source_id",
            "item_id": source_id,
            "task_id": task_id,
            "status": status,
        }
        # Only forward ``extra`` when the caller supplied one, so calls without
        # it keep relying on merge_fanout_item's own default.
        if extra is not None:
            fields["extra"] = extra
        merge_fanout_item(cid, **fields)
    except Exception as e:
        logger.warning(
            "event=memory_fanout_update_failed list_name={} source_id={} task_id={} "
            "status={} exc={} exc_type={}",
            list_name,
            source_id,
            task_id,
            status,
            e,
            type(e).__name__,
        )


def _enqueue_memory_task(
    task: Any,
    user_id: str,
    source_id: str,
    memoir_correlation_id: str | None,
    *,
    list_name: str,
    failure_message: str,
) -> str | None:
    """Enqueue ``task`` for one source on the memory queue and record it as enqueued.

    Args:
        task: The Celery task object to enqueue.
        user_id: Owner of the source; surrounding whitespace is stripped.
        source_id: Source to process; surrounding whitespace is stripped.
        memoir_correlation_id: Optional pipeline correlation id. It is
            normalised once and the same value is used both for the
            "enqueued" progress record and for the task's keyword argument,
            so later status updates land on the same progress entry.
        list_name: Fan-out list the progress item belongs to.
        failure_message: Log format string used when enqueueing fails; it
            receives user id, source id, exception and exception type name.

    Returns:
        The Celery task id as a string, or ``None`` when either id is blank,
        enqueueing raised, or the broker returned no task id.
    """
    uid = (user_id or "").strip()
    sid = (source_id or "").strip()
    if not uid or not sid:
        return None

    queue = (celery_defaults.memory_enrichment_queue or "").strip() or "memory_idle"
    try:
        cid = _normalize_correlation_id(memoir_correlation_id)
        async_result = cast(Any, task).apply_async(
            args=[uid, sid],
            kwargs={"memoir_correlation_id": cid},
            queue=queue,
        )
    except Exception as e:
        logger.warning(failure_message, uid, sid, e, type(e).__name__)
        return None

    task_id = getattr(async_result, "id", None)
    if not task_id:
        return None

    # Outside the try block on purpose: the task is already enqueued, so a
    # progress-store hiccup must not be reported as a scheduling failure.
    _record_fanout(
        cid,
        list_name=list_name,
        source_id=sid,
        task_id=str(task_id),
        status="enqueued",
    )
    return str(task_id)


def schedule_memory_embedding(
    user_id: str,
    source_id: str,
    *,
    memoir_correlation_id: str | None = None,
) -> str | None:
    """Enqueue embedding retry for a persisted memory source.

    Returns the Celery task id, or ``None`` when the ids are blank or the
    task could not be enqueued. When ``memoir_correlation_id`` is set, the
    item is recorded under ``memory_embedding`` with status ``enqueued``.
    """
    return _enqueue_memory_task(
        embed_memory_source,
        user_id,
        source_id,
        memoir_correlation_id,
        list_name="memory_embedding",
        failure_message=(
            "event=memory_embedding_schedule_failed user_id={} source_id={} exc={} exc_type={}"
        ),
    )


def schedule_memory_enrichment(
    user_id: str,
    source_id: str,
    *,
    memoir_correlation_id: str | None = None,
) -> str | None:
    """
    Enqueue post-ingest LLM enrichment on the memory idle queue.

    When ``memoir_correlation_id`` is set, records ``fanout.memory_enrichment``
    as enqueued for eval / pipeline progress (same as the former Phase1 loop).

    Returns the Celery task id, or ``None`` when enrichment is disabled, the
    ids are blank, or the task could not be enqueued.
    """
    if not memory.enrichment_enabled:
        return None
    return _enqueue_memory_task(
        enrich_memory_source,
        user_id,
        source_id,
        memoir_correlation_id,
        list_name="memory_enrichment",
        failure_message=(
            "event=memory_enrichment_schedule_failed user_id={} source_id={} exc={} exc_type={}"
        ),
    )


@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def embed_memory_source(
    self,
    user_id: str,
    source_id: str,
    memoir_correlation_id: str | None = None,
):
    """Post-ingest embedding retry for persisted chunks.

    Marks the fan-out item ``running``, runs the embedding in a fresh event
    loop, then marks it ``success`` (with the service result as ``extra``) or
    ``failure`` (with the error text). Any exception triggers a Celery retry.

    Returns:
        The service result dict with ``source_id`` added.
    """
    tid = str(self.request.id)
    t0 = time.perf_counter()
    logger.info(
        "event=memory_embedding_start user_id={} source_id={} task_id={} msg=开始记忆向量化",
        user_id,
        source_id,
        tid,
    )
    _record_fanout(
        memoir_correlation_id,
        list_name="memory_embedding",
        source_id=source_id,
        task_id=tid,
        status="running",
    )
    try:
        with business_span("memory.embed_source"):
            result = asyncio.run(_embed_memory_source_async(user_id, source_id))
        ms = (time.perf_counter() - t0) * 1000
        logger.info(
            "event=memory_embedding_done user_id={} source_id={} duration_ms={:.1f} status={} vectors_written={} msg=记忆向量化完成",
            user_id,
            source_id,
            ms,
            result.get("status"),
            result.get("vectors_written", 0),
        )
        _record_fanout(
            memoir_correlation_id,
            list_name="memory_embedding",
            source_id=source_id,
            task_id=tid,
            status="success",
            extra=result,
        )
        return {"source_id": source_id, **result}
    except Exception as e:
        ms = (time.perf_counter() - t0) * 1000
        logger.warning(
            "event=memory_embedding_failed user_id={} source_id={} duration_ms={:.1f} "
            "exc={} exc_type={} msg=记忆向量化失败",
            user_id,
            source_id,
            ms,
            e,
            type(e).__name__,
        )
        # Best-effort: a progress-store error here must not stop the retry below.
        _record_fanout(
            memoir_correlation_id,
            list_name="memory_embedding",
            source_id=source_id,
            task_id=tid,
            status="failure",
            extra={"error": str(e)},
        )
        raise self.retry(exc=e) from e


@shared_task(bind=True, max_retries=2, default_retry_delay=30)
def enrich_memory_source(
    self,
    user_id: str,
    source_id: str,
    memoir_correlation_id: str | None = None,
):
    """
    Post-ingest enrichment: one LLM call → session summary + structured facts.

    Runs outside the memoir Phase1 hot path so narrative generation isn't blocked.

    Returns ``{"status": "disabled"}`` when enrichment is switched off,
    otherwise ``{"status": "success", "source_id": ...}``. Any exception marks
    the fan-out item ``failure`` and triggers a Celery retry.
    """
    if not memory.enrichment_enabled:
        return {"status": "disabled"}
    tid = str(self.request.id)
    t0 = time.perf_counter()
    logger.info(
        "event=memory_enrichment_start user_id={} source_id={} task_id={} "
        "msg=开始记忆富化(会话摘要+事实)",
        user_id,
        source_id,
        tid,
    )
    _record_fanout(
        memoir_correlation_id,
        list_name="memory_enrichment",
        source_id=source_id,
        task_id=tid,
        status="running",
    )
    try:
        with business_span("memory.enrich_source"):
            asyncio.run(_enrich_memory_source_async(user_id, source_id))
        ms = (time.perf_counter() - t0) * 1000
        logger.info(
            "event=memory_enrichment_done user_id={} source_id={} duration_ms={:.1f} "
            "msg=记忆富化完成",
            user_id,
            source_id,
            ms,
        )
        _record_fanout(
            memoir_correlation_id,
            list_name="memory_enrichment",
            source_id=source_id,
            task_id=tid,
            status="success",
        )
        return {"status": "success", "source_id": source_id}
    except Exception as e:
        ms = (time.perf_counter() - t0) * 1000
        logger.warning(
            "event=memory_enrichment_failed user_id={} source_id={} duration_ms={:.1f} "
            "exc={} exc_type={} msg=记忆富化失败",
            user_id,
            source_id,
            ms,
            e,
            type(e).__name__,
        )
        # Best-effort: a progress-store error here must not stop the retry below.
        _record_fanout(
            memoir_correlation_id,
            list_name="memory_enrichment",
            source_id=source_id,
            task_id=tid,
            status="failure",
            extra={"error": str(e)},
        )
        raise self.retry(exc=e) from e