283 lines
8.4 KiB
Python
283 lines
8.4 KiB
Python
"""
|
|
Memory pipeline Celery tasks — retry embedding and enrichment after durable ingest.
|
|
|
|
Tasks are routed to ``settings.celery_memory_enrichment_queue`` (default ``memory_idle``);
|
|
run workers with ``-Q celery,memory_idle`` or a dedicated low-priority worker for that queue.
|
|
"""
|
|
|
|
import asyncio
|
|
import time
|
|
from typing import Any, cast
|
|
|
|
from celery import shared_task
|
|
|
|
from app.core.config import settings
|
|
from app.core.db import AsyncSessionLocal
|
|
from app.core.dependencies import get_embedding_provider
|
|
from app.core.logging import get_logger
|
|
from app.core.memoir_pipeline_progress import merge_fanout_item
|
|
from app.features.memory.service import MemoryService
|
|
|
|
# Module-level logger shared by the scheduling helpers and Celery tasks below.
# Messages use ``{}`` placeholders with positional args, so the logger returned
# by ``get_logger`` presumably formats brace-style (loguru-like) — confirm.
logger = get_logger(__name__)
|
|
|
|
|
|
async def _enrich_memory_source_async(
    user_id: str,
    source_id: str,
) -> None:
    """Run post-ingest enrichment for one memory source in a fresh DB session.

    Delegates to ``MemoryService.enrich_source`` with ``llm=None`` (presumably
    letting the service resolve its own default model — confirm) and commits
    only when that call returns without raising; an exception propagates to
    the caller and nothing is committed here.
    """
    async with AsyncSessionLocal() as session:
        memory_service = MemoryService(session)
        await memory_service.enrich_source(user_id, source_id, llm=None)
        await session.commit()
|
|
|
|
|
|
async def _embed_memory_source_async(
    user_id: str,
    source_id: str,
) -> dict:
    """Embed the persisted chunks of one memory source and commit the result.

    Opens a fresh ``AsyncSessionLocal`` session, builds a ``MemoryService``
    with the configured embedding provider, and calls ``embed_source`` with
    ``raise_on_failure=True`` so that (presumably) embedding errors raise here
    and let the calling Celery task retry — confirm against
    ``MemoryService.embed_source``.

    Returns:
        The dict returned by ``MemoryService.embed_source``; the task reads
        its ``status`` and ``vectors_written`` keys for logging.
    """
    async with AsyncSessionLocal() as session:
        memory_service = MemoryService(
            session,
            embedding_provider=get_embedding_provider(),
        )
        embed_outcome = await memory_service.embed_source(
            user_id,
            source_id,
            raise_on_failure=True,
        )
        await session.commit()
    return embed_outcome
|
|
|
|
|
|
def schedule_memory_embedding(
|
|
user_id: str,
|
|
source_id: str,
|
|
*,
|
|
memoir_correlation_id: str | None = None,
|
|
) -> str | None:
|
|
"""Enqueue embedding retry for a persisted memory source."""
|
|
uid = (user_id or "").strip()
|
|
sid = (source_id or "").strip()
|
|
if not uid or not sid:
|
|
return None
|
|
q = (settings.celery_memory_enrichment_queue or "").strip() or "memory_idle"
|
|
try:
|
|
task = cast(Any, embed_memory_source)
|
|
ar = task.apply_async(
|
|
args=[uid, sid],
|
|
kwargs={"memoir_correlation_id": memoir_correlation_id},
|
|
queue=q,
|
|
)
|
|
emb_id = getattr(ar, "id", None)
|
|
if not emb_id:
|
|
return None
|
|
cid = (memoir_correlation_id or "").strip()
|
|
if cid:
|
|
merge_fanout_item(
|
|
cid,
|
|
list_name="memory_embedding",
|
|
id_field="source_id",
|
|
item_id=sid,
|
|
task_id=str(emb_id),
|
|
status="enqueued",
|
|
)
|
|
return str(emb_id)
|
|
except Exception as e:
|
|
logger.warning(
|
|
"event=memory_embedding_schedule_failed user_id={} source_id={} exc={} exc_type={}",
|
|
uid,
|
|
sid,
|
|
e,
|
|
type(e).__name__,
|
|
)
|
|
return None
|
|
|
|
|
|
def schedule_memory_enrichment(
    user_id: str,
    source_id: str,
    *,
    memoir_correlation_id: str | None = None,
) -> str | None:
    """
    Enqueue post-ingest LLM enrichment on the memory idle queue.

    Does nothing (returns ``None``) when ``settings.memory_enrichment_enabled``
    is false. The task is routed to ``settings.celery_memory_enrichment_queue``,
    falling back to ``"memory_idle"`` when that setting is unset or blank.

    When ``memoir_correlation_id`` is non-blank, records ``fanout.memory_enrichment``
    as enqueued for eval / pipeline progress (same as the former Phase1 loop), and
    forwards the same normalized id to the task so its later status updates land
    on the same record.

    Args:
        user_id: Owner of the source; surrounding whitespace is stripped.
        source_id: Memory source to enrich; surrounding whitespace is stripped.
        memoir_correlation_id: Optional pipeline correlation id.

    Returns:
        The Celery task id as a string, or ``None`` when enrichment is disabled,
        either id is blank, enqueueing raised, or the broker returned no task id.
        A failure while recording pipeline progress does NOT turn a successful
        enqueue into ``None``; it is only logged.
    """
    if not settings.memory_enrichment_enabled:
        return None
    uid = (user_id or "").strip()
    sid = (source_id or "").strip()
    if not uid or not sid:
        return None
    cid = (memoir_correlation_id or "").strip()
    q = (settings.celery_memory_enrichment_queue or "").strip() or "memory_idle"
    try:
        task = cast(Any, enrich_memory_source)
        ar = task.apply_async(
            args=[uid, sid],
            # Forward the normalized id (or None) so the task's running/success/
            # failure updates use the same key as the "enqueued" entry below.
            kwargs={"memoir_correlation_id": cid or None},
            queue=q,
        )
    except Exception as e:
        logger.warning(
            "event=memory_enrichment_schedule_failed user_id={} source_id={} exc={} exc_type={}",
            uid,
            sid,
            e,
            type(e).__name__,
        )
        return None

    enr_id = getattr(ar, "id", None)
    if not enr_id:
        return None
    task_id = str(enr_id)

    if cid:
        # Progress bookkeeping is best-effort: the task is already queued, so a
        # failure here must not make the caller believe scheduling failed.
        try:
            merge_fanout_item(
                cid,
                list_name="memory_enrichment",
                id_field="source_id",
                item_id=sid,
                task_id=task_id,
                status="enqueued",
            )
        except Exception as e:
            logger.warning(
                "event=memory_enrichment_fanout_record_failed user_id={} source_id={} "
                "task_id={} exc={} exc_type={}",
                uid,
                sid,
                task_id,
                e,
                type(e).__name__,
            )
    return task_id
|
|
|
|
|
|
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def embed_memory_source(
    self,
    user_id: str,
    source_id: str,
    memoir_correlation_id: str | None = None,
):
    """Celery task: embed a persisted memory source's chunks, retrying on error.

    Reports progress to the ``memory_embedding`` fanout list keyed by
    ``memoir_correlation_id``: ``running`` first, then either ``success``
    (with the embedding result as ``extra``) or ``failure`` (with the error
    text as ``extra``). Any exception raised while embedding, logging, or
    reporting success is logged and passed to ``self.retry`` (the decorator
    allows up to 3 retries, 30 s apart).

    Returns:
        ``{"source_id": source_id, **result}``, where ``result`` is the dict
        returned by the embedding coroutine.
    """
    task_id = str(self.request.id)

    def report(status: str, **extra_kwargs: object) -> None:
        # Single place that records this task's state for pipeline progress;
        # ``extra`` is only passed through when the caller supplies it.
        merge_fanout_item(
            memoir_correlation_id,
            list_name="memory_embedding",
            id_field="source_id",
            item_id=source_id,
            task_id=task_id,
            status=status,
            **extra_kwargs,
        )

    started_at = time.perf_counter()
    logger.info(
        "event=memory_embedding_start user_id={} source_id={} task_id={} msg=开始记忆向量化",
        user_id,
        source_id,
        task_id,
    )
    report("running")

    try:
        outcome = asyncio.run(_embed_memory_source_async(user_id, source_id))
        elapsed_ms = (time.perf_counter() - started_at) * 1000
        logger.info(
            "event=memory_embedding_done user_id={} source_id={} duration_ms={:.1f} status={} vectors_written={} msg=记忆向量化完成",
            user_id,
            source_id,
            elapsed_ms,
            outcome.get("status"),
            outcome.get("vectors_written", 0),
        )
        report("success", extra=outcome)
        return {"source_id": source_id, **outcome}
    except Exception as exc:
        elapsed_ms = (time.perf_counter() - started_at) * 1000
        logger.warning(
            "event=memory_embedding_failed user_id={} source_id={} duration_ms={:.1f} "
            "exc={} exc_type={} msg=记忆向量化失败",
            user_id,
            source_id,
            elapsed_ms,
            exc,
            type(exc).__name__,
        )
        report("failure", extra={"error": str(exc)})
        raise self.retry(exc=exc) from exc
|
|
|
|
|
|
@shared_task(bind=True, max_retries=2, default_retry_delay=30)
def enrich_memory_source(
    self,
    user_id: str,
    source_id: str,
    memoir_correlation_id: str | None = None,
):
    """
    Post-ingest enrichment: one LLM call → session summary + structured facts.

    Runs outside the memoir Phase1 hot path so narrative generation isn't blocked.
    Returns ``{"status": "disabled"}`` immediately when
    ``settings.memory_enrichment_enabled`` is false. Otherwise reports progress
    to the ``memory_enrichment`` fanout list (``running``, then ``success`` or
    ``failure`` with the error text as ``extra``). Any exception is logged and
    passed to ``self.retry`` (the decorator allows up to 2 retries, 30 s apart).

    Returns:
        ``{"status": "success", "source_id": source_id}`` on success.
    """
    if not settings.memory_enrichment_enabled:
        return {"status": "disabled"}

    task_id = str(self.request.id)

    def report(status: str, **extra_kwargs: object) -> None:
        # Single place that records this task's state for pipeline progress;
        # ``extra`` is only passed through when the caller supplies it.
        merge_fanout_item(
            memoir_correlation_id,
            list_name="memory_enrichment",
            id_field="source_id",
            item_id=source_id,
            task_id=task_id,
            status=status,
            **extra_kwargs,
        )

    started_at = time.perf_counter()
    logger.info(
        "event=memory_enrichment_start user_id={} source_id={} task_id={} "
        "msg=开始记忆富化(会话摘要+事实)",
        user_id,
        source_id,
        task_id,
    )
    report("running")

    try:
        asyncio.run(_enrich_memory_source_async(user_id, source_id))
        elapsed_ms = (time.perf_counter() - started_at) * 1000
        logger.info(
            "event=memory_enrichment_done user_id={} source_id={} duration_ms={:.1f} "
            "msg=记忆富化完成",
            user_id,
            source_id,
            elapsed_ms,
        )
        report("success")
        return {"status": "success", "source_id": source_id}
    except Exception as exc:
        elapsed_ms = (time.perf_counter() - started_at) * 1000
        logger.warning(
            "event=memory_enrichment_failed user_id={} source_id={} duration_ms={:.1f} "
            "exc={} exc_type={} msg=记忆富化失败",
            user_id,
            source_id,
            elapsed_ms,
            exc,
            type(exc).__name__,
        )
        report("failure", extra={"error": str(exc)})
        raise self.retry(exc=exc) from exc
|