Files
life-echo/api/app/features/memoir/background_runner.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

192 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Memoir background tasks: batch segments behind a debounce and dispatch process_memoir_phase1; on flush, also trigger pending-narration Phase 2."""
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass, field
from typing import Dict, List, Sequence
from app.core.logging import get_logger
from app.core.task_tracker import task_tracker
from app.features.memoir.constants import memoir
# Module-level logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)
def _batch_ready_for_submit(
*,
min_chars: int,
max_wait_seconds: float,
total_text_chars: int,
elapsed_seconds: float,
) -> bool:
"""字数门闸开启时,静默结束后是否应提交(不含 min_chars==0 的早退,由调用方处理)。"""
if min_chars <= 0:
return True
if total_text_chars >= min_chars:
return True
if max_wait_seconds <= 0:
return True
return elapsed_seconds >= max_wait_seconds
def _next_retry_sleep_seconds(
debounce_seconds: float,
max_wait_seconds: float,
elapsed_seconds: float,
) -> float:
"""未达字数且未超时:下次再 sleep 的秒数。"""
return min(debounce_seconds, max(0.0, max_wait_seconds - elapsed_seconds))
@dataclass
class _MemoirBatchState:
segment_ids: list[str] = field(default_factory=list)
total_text_chars: int = 0
first_queued_monotonic: float | None = None
class BackgroundTaskRunner:
    """Per-user debouncer that batches memoir segments into Phase 1 Celery tasks.

    ``queue_message`` appends a segment to the user's in-memory batch and
    (re)starts a debounce timer. When the timer fires, the batch is handed to
    ``process_memoir_phase1`` as soon as the character gate read from
    ``memoir`` allows it. ``flush_pending`` submits whatever is batched right
    away and then dispatches pending Phase 2 work for the user.

    All state is held in two dicts keyed by user id on this instance.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        """Create an empty runner.

        Args:
            debounce_seconds: Quiet period, in seconds, slept after the most
                recent ``queue_message`` call before a submission is attempted.
        """
        self.debounce_seconds = debounce_seconds
        self._batch: Dict[str, _MemoirBatchState] = {}
        self._timers: Dict[str, asyncio.Task[None]] = {}

    def _pop_batch(self, user_id: str) -> list[str]:
        """Remove the user's batch and return its segment ids.

        Returns:
            The batch's own ``segment_ids`` list (not a copy), or a new empty
            list when the user has no batch or the batch holds no segments.
        """
        state = self._batch.pop(user_id, None)
        if not state or not state.segment_ids:
            return []
        return state.segment_ids

    def _discard_timer(self, user_id: str, timer: asyncio.Task[None]) -> None:
        """Forget ``timer`` if it is still the timer registered for ``user_id``.

        Used as a done-callback so that finished timers do not pile up in
        ``_timers``. The identity check keeps a newer timer, installed by a
        later ``queue_message`` call, from being removed by an older one.
        """
        if self._timers.get(user_id) is timer:
            del self._timers[user_id]

    async def _flush_pending_phase2(self, user_id: str) -> None:
        """Dispatch the user's pending Phase 2 work; failures are only logged.

        The dispatch function is run through ``asyncio.to_thread`` so it does
        not execute on the event-loop thread.
        """
        try:
            # Imported lazily, as in the rest of this class.
            from app.tasks.memoir_tasks import dispatch_pending_memoir_phase2_for_user

            await asyncio.to_thread(dispatch_pending_memoir_phase2_for_user, user_id)
        except Exception as e:
            logger.error(
                "flush Phase2 失败: user_id={} exc_type={} exc={}",
                user_id,
                type(e).__name__,
                e,
            )

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Submit ``process_memoir_phase1`` for the segments and register it.

        Args:
            user_id: Owner of the segments.
            segment_ids: Segment ids passed to the Celery task.

        Returns:
            The Celery task id, or ``None`` when anything in the sequence
            raised. That includes ``task_tracker.add_task`` failing after the
            Celery task had already been dispatched.
        """
        try:
            from app.tasks.memoir_tasks import process_memoir_phase1

            result = process_memoir_phase1.delay(user_id, segment_ids)
            task_id = result.id
            await task_tracker.add_task(user_id, task_id, "memoir")
            logger.info(
                "已提交 Celery 任务: user_id={}, task_id={}, segments={}",
                user_id,
                task_id,
                len(segment_ids),
            )
            return task_id
        except Exception as e:
            logger.error("提交 Celery 任务失败: {}", e)
            return None

    async def _submit_popped_batch(self, user_id: str) -> None:
        """Pop the user's batch and submit it when it contains any segment."""
        segment_ids = self._pop_batch(user_id)
        if segment_ids:
            await self._submit_task(user_id, segment_ids)

    async def _delayed_submit(self, user_id: str) -> None:
        """Timer body: wait out the debounce, then submit once the gate opens.

        After the initial debounce sleep the loop re-reads the batch and the
        ``memoir`` settings on every pass. The batch is submitted when the
        character gate is off, when no first-queued timestamp is recorded,
        when ``_batch_ready_for_submit`` says so, or when no wait budget is
        left; otherwise the coroutine sleeps and checks again.

        Cancellation ends the coroutine silently; any other exception is
        logged and swallowed.

        NOTE(review): ``cancel()`` can also arrive while this coroutine is
        awaiting ``_submit_task``. ``CancelledError`` is not an ``Exception``
        subclass, so it would interrupt ``task_tracker.add_task`` after the
        Celery task was already dispatched. Confirm whether that await needs
        to be shielded.
        """
        try:
            await asyncio.sleep(self.debounce_seconds)
            while True:
                batch = self._batch.get(user_id)
                if not batch or not batch.segment_ids:
                    # Nothing left to send (e.g. already flushed).
                    return

                min_chars = int(memoir.segment_batch_min_chars)
                max_wait = float(memoir.segment_batch_max_wait_seconds)
                first_queued = batch.first_queued_monotonic

                # Gate disabled, or no start time to measure the wait against.
                if min_chars <= 0 or first_queued is None:
                    await self._submit_popped_batch(user_id)
                    return

                elapsed = time.monotonic() - first_queued
                if _batch_ready_for_submit(
                    min_chars=min_chars,
                    max_wait_seconds=max_wait,
                    total_text_chars=batch.total_text_chars,
                    elapsed_seconds=elapsed,
                ):
                    await self._submit_popped_batch(user_id)
                    return

                sleep_more = _next_retry_sleep_seconds(
                    float(self.debounce_seconds),
                    max_wait,
                    elapsed,
                )
                if sleep_more <= 0:
                    # No budget left to wait for more text.
                    await self._submit_popped_batch(user_id)
                    return
                await asyncio.sleep(sleep_more)
        except asyncio.CancelledError:
            # The timer was replaced by a newer one or cancelled by a flush.
            pass
        except Exception as e:
            logger.error("延迟提交任务失败: {}", e)

    async def queue_message(
        self, user_id: str, segment_id: str, *, text_char_count: int = 0
    ) -> None:
        """Add a segment to the user's batch and restart the debounce timer.

        Must be called with a running event loop, because the timer is
        created with ``asyncio.create_task``.

        Args:
            user_id: Owner of the segment.
            segment_id: Id appended to the user's batch.
            text_char_count: Characters contributed by the segment; negative
                values are counted as 0.
        """
        state = self._batch.setdefault(user_id, _MemoirBatchState())
        if not state.segment_ids:
            # First segment of a fresh batch: start the max-wait clock.
            state.first_queued_monotonic = time.monotonic()
        state.segment_ids.append(segment_id)
        state.total_text_chars += max(0, text_char_count)

        previous = self._timers.get(user_id)
        if previous is not None:
            previous.cancel()

        timer = asyncio.create_task(self._delayed_submit(user_id))
        # Drop the entry once the timer finishes so completed tasks are not
        # retained for users that never call flush_pending.
        timer.add_done_callback(lambda done: self._discard_timer(user_id, done))
        self._timers[user_id] = timer

    def _dedupe_preserve_order(self, ids: Sequence[str]) -> list[str]:
        """Return ``ids`` without duplicates, keeping first occurrences in order."""
        return list(dict.fromkeys(ids))

    async def flush_pending(
        self,
        user_id: str,
        *,
        extra_segment_ids: Sequence[str] | None = None,
    ) -> str | None:
        """Submit the user's batch immediately, then dispatch pending Phase 2.

        Args:
            user_id: User whose batch is flushed.
            extra_segment_ids: Additional segment ids appended after the
                batched ones; duplicates are dropped.

        Returns:
            The Phase 1 Celery task id, or ``None`` when there was nothing to
            submit or the submission failed.
        """
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()

        batched = self._pop_batch(user_id)
        merged = self._dedupe_preserve_order(
            list(batched) + list(extra_segment_ids or ())
        )

        task_id: str | None = None
        if merged:
            task_id = await self._submit_task(user_id, merged)
        # Phase 2 dispatch runs even when no Phase 1 task was submitted.
        await self._flush_pending_phase2(user_id)
        return task_id