Files
life-echo/api/app/features/memoir/background_runner.py

192 lines
6.5 KiB
Python
Raw Normal View History

"""Memoir background-task aggregation: after the debounce window, dispatch process_memoir_phase1; on flush, also trigger any pending narrative Phase 2."""
2026-03-19 14:36:14 +08:00
2026-03-19 10:38:11 +08:00
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass, field
from typing import Dict, List, Sequence
2026-03-19 10:38:11 +08:00
from app.core.config import settings
2026-03-19 10:38:11 +08:00
from app.core.logging import get_logger
from app.core.task_tracker import task_tracker
logger = get_logger(__name__)
def _batch_ready_for_submit(
*,
min_chars: int,
max_wait_seconds: float,
total_text_chars: int,
elapsed_seconds: float,
) -> bool:
"""字数门闸开启时,静默结束后是否应提交(不含 min_chars==0 的早退,由调用方处理)。"""
if min_chars <= 0:
return True
if total_text_chars >= min_chars:
return True
if max_wait_seconds <= 0:
return True
return elapsed_seconds >= max_wait_seconds
def _next_retry_sleep_seconds(
debounce_seconds: float,
max_wait_seconds: float,
elapsed_seconds: float,
) -> float:
"""未达字数且未超时:下次再 sleep 的秒数。"""
return min(debounce_seconds, max(0.0, max_wait_seconds - elapsed_seconds))
@dataclass
class _MemoirBatchState:
segment_ids: list[str] = field(default_factory=list)
total_text_chars: int = 0
first_queued_monotonic: float | None = None
2026-03-19 10:38:11 +08:00
class BackgroundTaskRunner:
def __init__(self, debounce_seconds: int = 5) -> None:
self.debounce_seconds = debounce_seconds
self._batch: Dict[str, _MemoirBatchState] = {}
self._timers: Dict[str, asyncio.Task[None]] = {}
def _pop_batch(self, user_id: str) -> list[str]:
st = self._batch.pop(user_id, None)
if not st or not st.segment_ids:
return []
ids = st.segment_ids
return ids
2026-03-19 10:38:11 +08:00
async def _flush_pending_phase2(self, user_id: str) -> None:
try:
from app.tasks.memoir_tasks import dispatch_pending_memoir_phase2_for_user
await asyncio.to_thread(dispatch_pending_memoir_phase2_for_user, user_id)
except Exception as e:
logger.error(
"flush Phase2 失败: user_id={} exc_type={} exc={}",
user_id,
type(e).__name__,
e,
)
2026-03-19 10:38:11 +08:00
async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
try:
from app.tasks.memoir_tasks import process_memoir_phase1
2026-03-19 10:38:11 +08:00
result = process_memoir_phase1.delay(user_id, segment_ids)
2026-03-19 10:38:11 +08:00
task_id = result.id
await task_tracker.add_task(user_id, task_id, "memoir")
logger.info(
"已提交 Celery 任务: user_id={}, task_id={}, segments={}",
2026-03-19 10:38:11 +08:00
user_id,
task_id,
len(segment_ids),
)
return task_id
except Exception as e:
logger.error("提交 Celery 任务失败: {}", e)
2026-03-19 10:38:11 +08:00
return None
async def queue_message(
self, user_id: str, segment_id: str, *, text_char_count: int = 0
) -> None:
st = self._batch.setdefault(user_id, _MemoirBatchState())
if not st.segment_ids:
st.first_queued_monotonic = time.monotonic()
st.segment_ids.append(segment_id)
st.total_text_chars += max(0, text_char_count)
2026-03-19 10:38:11 +08:00
if user_id in self._timers:
self._timers[user_id].cancel()
async def delayed_submit() -> None:
2026-03-19 10:38:11 +08:00
try:
await asyncio.sleep(self.debounce_seconds)
while True:
if user_id not in self._batch:
return
batch = self._batch.get(user_id)
if not batch or not batch.segment_ids:
return
min_c = int(settings.memoir_segment_batch_min_chars)
max_w = float(settings.memoir_segment_batch_max_wait_seconds)
if min_c <= 0:
segment_ids = self._pop_batch(user_id)
if segment_ids:
await self._submit_task(user_id, segment_ids)
return
first = batch.first_queued_monotonic
if first is None:
segment_ids = self._pop_batch(user_id)
if segment_ids:
await self._submit_task(user_id, segment_ids)
return
now = time.monotonic()
elapsed = now - first
total = batch.total_text_chars
if _batch_ready_for_submit(
min_chars=min_c,
max_wait_seconds=max_w,
total_text_chars=total,
elapsed_seconds=elapsed,
):
segment_ids = self._pop_batch(user_id)
if segment_ids:
await self._submit_task(user_id, segment_ids)
return
sleep_more = _next_retry_sleep_seconds(
float(self.debounce_seconds),
max_w,
elapsed,
)
if sleep_more <= 0:
segment_ids = self._pop_batch(user_id)
if segment_ids:
await self._submit_task(user_id, segment_ids)
return
await asyncio.sleep(sleep_more)
2026-03-19 10:38:11 +08:00
except asyncio.CancelledError:
pass
except Exception as e:
logger.error("延迟提交任务失败: {}", e)
2026-03-19 10:38:11 +08:00
self._timers[user_id] = asyncio.create_task(delayed_submit())
def _dedupe_preserve_order(self, ids: Sequence[str]) -> list[str]:
seen: set[str] = set()
out: list[str] = []
for sid in ids:
if sid not in seen:
seen.add(sid)
out.append(sid)
return out
async def flush_pending(
self,
user_id: str,
*,
extra_segment_ids: Sequence[str] | None = None,
) -> str | None:
2026-03-19 10:38:11 +08:00
if user_id in self._timers:
self._timers[user_id].cancel()
del self._timers[user_id]
popped = self._pop_batch(user_id)
merged = self._dedupe_preserve_order(
list(popped) + list(extra_segment_ids or ())
)
task_id: str | None = None
if merged:
task_id = await self._submit_task(user_id, merged)
await self._flush_pending_phase2(user_id)
return task_id