Files
life-echo/api/app/tasks/memory_enrichment_tasks.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

286 lines
8.5 KiB
Python

"""
Memory pipeline Celery tasks — retry embedding and enrichment after durable ingest.
Tasks are routed to ``celery_defaults.memory_enrichment_queue`` (default ``memory_idle``);
run workers with ``-Q celery,memory_idle`` or a dedicated low-priority worker for that queue.
"""
import asyncio
import time
from typing import Any, cast
from celery import shared_task
from app.core.business_telemetry import business_span
from app.core.config import settings
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_embedding_provider
from app.core.logging import get_logger
from app.core.memoir_pipeline_progress import merge_fanout_item
from app.features.memory.service import MemoryService
from app.core.runtime_constants import celery_defaults
from app.features.memory.constants import memory
# Module logger. Call sites in this module use brace-style "{}" placeholders with
# positional args, so get_logger presumably returns a loguru-style logger — confirm
# in app.core.logging.
logger = get_logger(__name__)
async def _enrich_memory_source_async(
    user_id: str,
    source_id: str,
) -> None:
    """
    Enrich one memory source inside a freshly opened DB session.

    Delegates to ``MemoryService.enrich_source`` with ``llm=None`` passed
    explicitly (presumably letting the service pick its default model — confirm
    in ``MemoryService.enrich_source``). The session is closed when the call
    returns or raises.
    """
    async with AsyncSessionLocal() as session:
        await MemoryService(session).enrich_source(user_id, source_id, llm=None)
async def _embed_memory_source_async(
    user_id: str,
    source_id: str,
) -> dict:
    """
    Embed one memory source inside a freshly opened DB session.

    Builds a ``MemoryService`` wired to the configured embedding provider and
    calls ``embed_source`` with ``raise_on_failure=True``, so embedding errors
    propagate to the caller (the Celery task, which retries on them).

    Returns:
        The dict produced by ``MemoryService.embed_source``.
    """
    async with AsyncSessionLocal() as session:
        svc = MemoryService(session, embedding_provider=get_embedding_provider())
        return await svc.embed_source(
            user_id,
            source_id,
            raise_on_failure=True,
        )
def schedule_memory_embedding(
user_id: str,
source_id: str,
*,
memoir_correlation_id: str | None = None,
) -> str | None:
"""Enqueue embedding retry for a persisted memory source."""
uid = (user_id or "").strip()
sid = (source_id or "").strip()
if not uid or not sid:
return None
q = (celery_defaults.memory_enrichment_queue or "").strip() or "memory_idle"
try:
task = cast(Any, embed_memory_source)
ar = task.apply_async(
args=[uid, sid],
kwargs={"memoir_correlation_id": memoir_correlation_id},
queue=q,
)
emb_id = getattr(ar, "id", None)
if not emb_id:
return None
cid = (memoir_correlation_id or "").strip()
if cid:
merge_fanout_item(
cid,
list_name="memory_embedding",
id_field="source_id",
item_id=sid,
task_id=str(emb_id),
status="enqueued",
)
return str(emb_id)
except Exception as e:
logger.warning(
"event=memory_embedding_schedule_failed user_id={} source_id={} exc={} exc_type={}",
uid,
sid,
e,
type(e).__name__,
)
return None
def schedule_memory_enrichment(
    user_id: str,
    source_id: str,
    *,
    memoir_correlation_id: str | None = None,
) -> str | None:
    """
    Enqueue post-ingest LLM enrichment on the memory idle queue.

    Does nothing (returns ``None``) when ``memory.enrichment_enabled`` is false.
    The task is sent to ``celery_defaults.memory_enrichment_queue`` (falling back
    to ``"memory_idle"`` when unset or blank). When ``memoir_correlation_id`` is
    non-blank, records ``fanout.memory_enrichment`` as enqueued for eval /
    pipeline progress (same as the former Phase1 loop).

    Args:
        user_id: Owner of the memory source; surrounding whitespace is stripped.
        source_id: Memory source to enrich; surrounding whitespace is stripped.
        memoir_correlation_id: Optional pipeline correlation id for progress
            tracking; blank values are treated as ``None``.

    Returns:
        The Celery task id, or ``None`` when enrichment is disabled,
        ``user_id``/``source_id`` is blank, enqueueing raises, or the broker
        returns no id. A failure to record fan-out progress is logged but does
        not hide a successfully enqueued task.
    """
    if not memory.enrichment_enabled:
        return None
    uid = (user_id or "").strip()
    sid = (source_id or "").strip()
    if not uid or not sid:
        return None
    # Normalize once and hand the same value to the task, so its running/success/
    # failure records land under the same key as the "enqueued" record below.
    cid = (memoir_correlation_id or "").strip() or None
    queue_name = (celery_defaults.memory_enrichment_queue or "").strip() or "memory_idle"
    try:
        task = cast(Any, enrich_memory_source)
        async_result = task.apply_async(
            args=[uid, sid],
            kwargs={"memoir_correlation_id": cid},
            queue=queue_name,
        )
    except Exception as e:
        logger.warning(
            "event=memory_enrichment_schedule_failed user_id={} source_id={} exc={} exc_type={}",
            uid,
            sid,
            e,
            type(e).__name__,
        )
        return None
    enr_id = getattr(async_result, "id", None)
    if not enr_id:
        return None
    if cid:
        # Progress tracking is best-effort: the task is already enqueued, so a
        # progress-store error must not turn this into a reported failure.
        try:
            merge_fanout_item(
                cid,
                list_name="memory_enrichment",
                id_field="source_id",
                item_id=sid,
                task_id=str(enr_id),
                status="enqueued",
            )
        except Exception as e:
            logger.warning(
                "event=memory_enrichment_progress_failed user_id={} source_id={} task_id={} exc={} exc_type={}",
                uid,
                sid,
                enr_id,
                e,
                type(e).__name__,
            )
    return str(enr_id)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def embed_memory_source(
    self,
    user_id: str,
    source_id: str,
    memoir_correlation_id: str | None = None,
):
    """
    Post-ingest embedding retry for persisted chunks.

    Runs ``_embed_memory_source_async`` in a fresh event loop via
    ``asyncio.run``. Reports ``running`` / ``success`` / ``failure`` into the
    ``memory_embedding`` fan-out list for ``memoir_correlation_id``.

    Progress reporting is best-effort. A progress-store error is logged and
    never fails the task, for three reasons:
      - a finished embedding is not re-run and marked ``failure`` just because
        the success record could not be written;
      - a failed embedding still reaches ``self.retry``;
      - the task is not aborted before doing any work.

    Returns:
        ``{"source_id": source_id, **result}``, where ``result`` is the dict
        returned by ``MemoryService.embed_source``.

    Raises:
        celery.exceptions.Retry: via ``self.retry`` on any embedding error
            (up to ``max_retries=3``, 30 s apart). Once retries are exhausted,
            Celery re-raises the original exception.
    """
    tid = str(self.request.id)
    t0 = time.perf_counter()

    def _report(status: str, extra: dict | None = None) -> None:
        # Best-effort fan-out progress write (see docstring for why it must not raise).
        fields: dict[str, Any] = {}
        if extra is not None:
            fields["extra"] = extra
        try:
            merge_fanout_item(
                memoir_correlation_id,
                list_name="memory_embedding",
                id_field="source_id",
                item_id=source_id,
                task_id=tid,
                status=status,
                **fields,
            )
        except Exception as report_exc:
            logger.warning(
                "event=memory_embedding_progress_failed user_id={} source_id={} task_id={} "
                "status={} exc={} exc_type={}",
                user_id,
                source_id,
                tid,
                status,
                report_exc,
                type(report_exc).__name__,
            )

    logger.info(
        "event=memory_embedding_start user_id={} source_id={} task_id={} msg=开始记忆向量化",
        user_id,
        source_id,
        tid,
    )
    _report("running")
    try:
        with business_span("memory.embed_source"):
            result = asyncio.run(_embed_memory_source_async(user_id, source_id))
        ms = (time.perf_counter() - t0) * 1000
        logger.info(
            "event=memory_embedding_done user_id={} source_id={} duration_ms={:.1f} status={} vectors_written={} msg=记忆向量化完成",
            user_id,
            source_id,
            ms,
            result.get("status"),
            result.get("vectors_written", 0),
        )
    except Exception as e:
        ms = (time.perf_counter() - t0) * 1000
        logger.warning(
            "event=memory_embedding_failed user_id={} source_id={} duration_ms={:.1f} "
            "exc={} exc_type={} msg=记忆向量化失败",
            user_id,
            source_id,
            ms,
            e,
            type(e).__name__,
        )
        _report("failure", extra={"error": str(e)})
        raise self.retry(exc=e) from e
    _report("success", extra=result)
    return {"source_id": source_id, **result}
@shared_task(bind=True, max_retries=2, default_retry_delay=30)
def enrich_memory_source(
    self,
    user_id: str,
    source_id: str,
    memoir_correlation_id: str | None = None,
):
    """
    Post-ingest enrichment: one LLM call → session summary + structured facts.

    Runs outside the memoir Phase1 hot path so narrative generation isn't
    blocked. Returns ``{"status": "disabled"}`` immediately, without touching
    progress, when ``memory.enrichment_enabled`` is false. Otherwise it runs
    ``_enrich_memory_source_async`` via ``asyncio.run`` and reports ``running`` /
    ``success`` / ``failure`` into the ``memory_enrichment`` fan-out list.

    Progress reporting is best-effort. A progress-store error is logged and
    never fails the task. This matters here because every retry repeats the LLM
    call:
      - a completed enrichment is not re-run just because the success record
        could not be written;
      - a genuine failure still reaches ``self.retry``.

    Returns:
        ``{"status": "success", "source_id": source_id}`` on success.

    Raises:
        celery.exceptions.Retry: via ``self.retry`` on any enrichment error
            (up to ``max_retries=2``, 30 s apart). Once retries are exhausted,
            Celery re-raises the original exception.
    """
    if not memory.enrichment_enabled:
        return {"status": "disabled"}
    tid = str(self.request.id)
    t0 = time.perf_counter()

    def _report(status: str, extra: dict | None = None) -> None:
        # Best-effort fan-out progress write (see docstring for why it must not raise).
        fields: dict[str, Any] = {}
        if extra is not None:
            fields["extra"] = extra
        try:
            merge_fanout_item(
                memoir_correlation_id,
                list_name="memory_enrichment",
                id_field="source_id",
                item_id=source_id,
                task_id=tid,
                status=status,
                **fields,
            )
        except Exception as report_exc:
            logger.warning(
                "event=memory_enrichment_progress_failed user_id={} source_id={} task_id={} "
                "status={} exc={} exc_type={}",
                user_id,
                source_id,
                tid,
                status,
                report_exc,
                type(report_exc).__name__,
            )

    logger.info(
        "event=memory_enrichment_start user_id={} source_id={} task_id={} "
        "msg=开始记忆富化(会话摘要+事实)",
        user_id,
        source_id,
        tid,
    )
    _report("running")
    try:
        with business_span("memory.enrich_source"):
            asyncio.run(_enrich_memory_source_async(user_id, source_id))
        ms = (time.perf_counter() - t0) * 1000
        logger.info(
            "event=memory_enrichment_done user_id={} source_id={} duration_ms={:.1f} "
            "msg=记忆富化完成",
            user_id,
            source_id,
            ms,
        )
    except Exception as e:
        ms = (time.perf_counter() - t0) * 1000
        logger.warning(
            "event=memory_enrichment_failed user_id={} source_id={} duration_ms={:.1f} "
            "exc={} exc_type={} msg=记忆富化失败",
            user_id,
            source_id,
            ms,
            e,
            type(e).__name__,
        )
        _report("failure", extra={"error": str(e)})
        raise self.retry(exc=e) from e
    _report("success")
    return {"status": "success", "source_id": source_id}