"""Celery:memory compaction(近重复 chunk 软排除)。""" from __future__ import annotations import asyncio import time from datetime import datetime from typing import Any from celery import shared_task from app.core.business_telemetry import business_span from app.core.db import AsyncSessionLocal, transactional from app.core.logging import get_logger from app.core.memory_compaction_schedule import ( finalize_memory_compaction_run, read_debounce_deadline_ts, release_scheduler_gate, schedule_memory_compaction_run, set_incremental_cursor_pair, ) from app.core.redis_lock import acquire_redis_lock, release_redis_lock from app.features.memory.repo import list_users_with_recent_chunks from app.features.memory.service import MemoryService from app.features.memory.constants import memory logger = get_logger(__name__) async def _list_users_with_recent_chunks_async(hours: int) -> list[str]: async with AsyncSessionLocal() as db: return await list_users_with_recent_chunks(db, hours=hours) async def _run_memory_compaction_async( user_id: str, context: dict[str, Any] | None, ) -> dict[str, Any]: async with AsyncSessionLocal() as db: async with transactional(db): service = MemoryService(db) return await service.compact_user(user_id, context) @shared_task(bind=True, ignore_result=True) def memory_compaction_sweep(self) -> dict[str, Any]: """Beat:为近期有记忆写入的用户调度 compaction(debounce 仍由 schedule 合并)。""" t0 = time.perf_counter() if not memory.compaction_enabled: return {"skipped": True, "reason": "disabled"} hours = int(memory.compaction_sweep_recent_hours) with business_span("memory.compaction.sweep", hours=hours): user_ids = asyncio.run(_list_users_with_recent_chunks_async(hours)) ctx_base: dict[str, Any] = {"trigger_source": "beat", "sweep_hours": hours} scheduled = 0 failed = 0 for uid in user_ids: try: schedule_memory_compaction_run(uid, dict(ctx_base)) scheduled += 1 except Exception as exc: failed += 1 logger.warning( "event=memory_compaction_sweep_schedule_failed user_id={} exc={} " "msg=单用户 compaction 调度失败,继续扫描", uid, exc, ) ms = (time.perf_counter() - t0) * 1000 logger.info( "event=memory_compaction_sweep_done hours={} scheduled_users={} failed_users={} " "duration_ms={:.1f} msg=记忆压缩定时扫描已调度", hours, scheduled, failed, ms, ) return {"scheduled": scheduled, "failed": failed, "hours": hours} @shared_task(bind=True, max_retries=12, default_retry_delay=20, ignore_result=True) def memory_compaction_run( self, user_id: str, context: dict[str, Any] | None = None ) -> dict[str, Any]: run_t0 = time.perf_counter() if not memory.compaction_enabled: return {"skipped": True, "reason": "disabled"} ctx = dict(context or {}) deadline = read_debounce_deadline_ts(user_id) now = time.time() if deadline is not None and now < deadline: delay = max(1.0, deadline - now) raise self.retry(countdown=int(delay)) lock = acquire_redis_lock( f"lock:memory_compaction:{user_id}", ttl_seconds=memory.compaction_lock_ttl_seconds, ) if lock is None: ms = (time.perf_counter() - run_t0) * 1000 logger.info( "event=memory_compaction_skipped user_id={} reason=lock_not_acquired " "duration_ms={:.1f} msg=记忆压缩跳过(未拿到锁)", user_id, ms, ) out = {"skipped": True, "reason": "lock_not_acquired"} finalize_memory_compaction_run( user_id, observed_deadline_ts=deadline, context=ctx, ) return out try: with business_span("memory.compaction.run"): out = asyncio.run(_run_memory_compaction_async(user_id, ctx)) if out.get("new_cursor_ts") and out.get("new_cursor_id") is not None: set_incremental_cursor_pair( user_id, datetime.fromisoformat(out["new_cursor_ts"]), str(out["new_cursor_id"]), ) finalize_memory_compaction_run( user_id, observed_deadline_ts=deadline, context=ctx, ) ms = (time.perf_counter() - run_t0) * 1000 logger.info( "event=memory_compaction_done user_id={} duration_ms={:.1f} msg=记忆压缩运行完成", user_id, ms, ) return out except Exception as exc: ms = (time.perf_counter() - run_t0) * 1000 logger.warning( "event=memory_compaction_failed user_id={} duration_ms={:.1f} err={} " "msg=记忆压缩运行失败", user_id, ms, exc, ) release_scheduler_gate(user_id) raise self.retry(exc=exc) from exc finally: release_redis_lock(lock)