"""回忆录后台任务聚合:debounce 后派发 process_memoir_phase1;flush 时触发待叙事 Phase2。""" from __future__ import annotations import asyncio import time from dataclasses import dataclass, field from typing import Dict, List, Sequence from app.core.config import settings from app.core.logging import get_logger from app.core.task_tracker import task_tracker logger = get_logger(__name__) def _batch_ready_for_submit( *, min_chars: int, max_wait_seconds: float, total_text_chars: int, elapsed_seconds: float, ) -> bool: """字数门闸开启时,静默结束后是否应提交(不含 min_chars==0 的早退,由调用方处理)。""" if min_chars <= 0: return True if total_text_chars >= min_chars: return True if max_wait_seconds <= 0: return True return elapsed_seconds >= max_wait_seconds def _next_retry_sleep_seconds( debounce_seconds: float, max_wait_seconds: float, elapsed_seconds: float, ) -> float: """未达字数且未超时:下次再 sleep 的秒数。""" return min(debounce_seconds, max(0.0, max_wait_seconds - elapsed_seconds)) @dataclass class _MemoirBatchState: segment_ids: list[str] = field(default_factory=list) total_text_chars: int = 0 first_queued_monotonic: float | None = None class BackgroundTaskRunner: def __init__(self, debounce_seconds: int = 5) -> None: self.debounce_seconds = debounce_seconds self._batch: Dict[str, _MemoirBatchState] = {} self._timers: Dict[str, asyncio.Task[None]] = {} def _pop_batch(self, user_id: str) -> list[str]: st = self._batch.pop(user_id, None) if not st or not st.segment_ids: return [] ids = st.segment_ids return ids async def _flush_pending_phase2(self, user_id: str) -> None: try: from app.tasks.memoir_tasks import dispatch_pending_memoir_phase2_for_user await asyncio.to_thread(dispatch_pending_memoir_phase2_for_user, user_id) except Exception as e: logger.error( "flush Phase2 失败: user_id={} exc_type={} exc={}", user_id, type(e).__name__, e, ) async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None: try: from app.tasks.memoir_tasks import process_memoir_phase1 result = process_memoir_phase1.delay(user_id, segment_ids) task_id = result.id await task_tracker.add_task(user_id, task_id, "memoir") logger.info( "已提交 Celery 任务: user_id={}, task_id={}, segments={}", user_id, task_id, len(segment_ids), ) return task_id except Exception as e: logger.error("提交 Celery 任务失败: {}", e) return None async def queue_message( self, user_id: str, segment_id: str, *, text_char_count: int = 0 ) -> None: st = self._batch.setdefault(user_id, _MemoirBatchState()) if not st.segment_ids: st.first_queued_monotonic = time.monotonic() st.segment_ids.append(segment_id) st.total_text_chars += max(0, text_char_count) if user_id in self._timers: self._timers[user_id].cancel() async def delayed_submit() -> None: try: await asyncio.sleep(self.debounce_seconds) while True: if user_id not in self._batch: return batch = self._batch.get(user_id) if not batch or not batch.segment_ids: return min_c = int(settings.memoir_segment_batch_min_chars) max_w = float(settings.memoir_segment_batch_max_wait_seconds) if min_c <= 0: segment_ids = self._pop_batch(user_id) if segment_ids: await self._submit_task(user_id, segment_ids) return first = batch.first_queued_monotonic if first is None: segment_ids = self._pop_batch(user_id) if segment_ids: await self._submit_task(user_id, segment_ids) return now = time.monotonic() elapsed = now - first total = batch.total_text_chars if _batch_ready_for_submit( min_chars=min_c, max_wait_seconds=max_w, total_text_chars=total, elapsed_seconds=elapsed, ): segment_ids = self._pop_batch(user_id) if segment_ids: await self._submit_task(user_id, segment_ids) return sleep_more = _next_retry_sleep_seconds( float(self.debounce_seconds), max_w, elapsed, ) if sleep_more <= 0: segment_ids = self._pop_batch(user_id) if segment_ids: await self._submit_task(user_id, segment_ids) return await asyncio.sleep(sleep_more) except asyncio.CancelledError: pass except Exception as e: logger.error("延迟提交任务失败: {}", e) self._timers[user_id] = asyncio.create_task(delayed_submit()) def _dedupe_preserve_order(self, ids: Sequence[str]) -> list[str]: seen: set[str] = set() out: list[str] = [] for sid in ids: if sid not in seen: seen.add(sid) out.append(sid) return out async def flush_pending( self, user_id: str, *, extra_segment_ids: Sequence[str] | None = None, ) -> str | None: if user_id in self._timers: self._timers[user_id].cancel() del self._timers[user_id] popped = self._pop_batch(user_id) merged = self._dedupe_preserve_order( list(popped) + list(extra_segment_ids or ()) ) task_id: str | None = None if merged: task_id = await self._submit_task(user_id, merged) await self._flush_pending_phase2(user_id) return task_id