配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM)
192 lines
6.5 KiB
Python
192 lines
6.5 KiB
Python
"""回忆录后台任务聚合:debounce 后派发 process_memoir_phase1;flush 时触发待叙事 Phase2。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from typing import Dict, List, Sequence
|
||
|
||
from app.core.logging import get_logger
|
||
from app.core.task_tracker import task_tracker
|
||
from app.features.memoir.constants import memoir
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
def _batch_ready_for_submit(
|
||
*,
|
||
min_chars: int,
|
||
max_wait_seconds: float,
|
||
total_text_chars: int,
|
||
elapsed_seconds: float,
|
||
) -> bool:
|
||
"""字数门闸开启时,静默结束后是否应提交(不含 min_chars==0 的早退,由调用方处理)。"""
|
||
if min_chars <= 0:
|
||
return True
|
||
if total_text_chars >= min_chars:
|
||
return True
|
||
if max_wait_seconds <= 0:
|
||
return True
|
||
return elapsed_seconds >= max_wait_seconds
|
||
|
||
|
||
def _next_retry_sleep_seconds(
|
||
debounce_seconds: float,
|
||
max_wait_seconds: float,
|
||
elapsed_seconds: float,
|
||
) -> float:
|
||
"""未达字数且未超时:下次再 sleep 的秒数。"""
|
||
return min(debounce_seconds, max(0.0, max_wait_seconds - elapsed_seconds))
|
||
|
||
|
||
@dataclass
|
||
class _MemoirBatchState:
|
||
segment_ids: list[str] = field(default_factory=list)
|
||
total_text_chars: int = 0
|
||
first_queued_monotonic: float | None = None
|
||
|
||
|
||
class BackgroundTaskRunner:
    """Debounce per-user memoir segments and hand them to Celery in batches.

    ``queue_message`` appends a segment to the user's batch and (re)starts a
    debounce timer. When the timer fires and the batch passes the
    character / max-wait gate, the batch is submitted to
    ``process_memoir_phase1``. ``flush_pending`` submits whatever is queued
    right away and then triggers the pending Phase 2 dispatch.

    NOTE(review): all state lives in in-process dicts, so this presumably
    expects to be driven from a single event loop — confirm.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        """Create a runner.

        Args:
            debounce_seconds: Quiet period to wait after the latest queued
                segment before the batch is evaluated for submission.
        """
        self.debounce_seconds = debounce_seconds
        # user_id -> segments accumulated since the last submission.
        self._batch: Dict[str, _MemoirBatchState] = {}
        # user_id -> debounce task that is still waiting and may be cancelled.
        self._timers: Dict[str, asyncio.Task[None]] = {}
        # Debounce tasks that already took their batch and are submitting it.
        # They are kept here (not in ``_timers``) so that they can no longer
        # be cancelled, while still being strongly referenced until done.
        self._inflight: set[asyncio.Task[None]] = set()

    def _pop_batch(self, user_id: str) -> list[str]:
        """Remove the user's batch and return its segment ids ([] if none)."""
        state = self._batch.pop(user_id, None)
        if state is None or not state.segment_ids:
            return []
        return state.segment_ids

    def _cancel_timer(self, user_id: str) -> None:
        """Unregister and cancel the user's waiting debounce task, if any."""
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()

    def _forget_timer(self, user_id: str, task: "asyncio.Task[None] | None") -> None:
        """Drop the registry entry for ``user_id`` only if it is ``task``."""
        # Identity check: a newer timer may have replaced this one already.
        if task is not None and self._timers.get(user_id) is task:
            del self._timers[user_id]

    async def _flush_pending_phase2(self, user_id: str) -> None:
        """Run the pending Phase 2 dispatch for the user in a worker thread.

        Any exception is logged and swallowed (best effort).
        """
        try:
            # Imported lazily, as in the original code (presumably to avoid
            # an import cycle with the tasks module — confirm).
            from app.tasks.memoir_tasks import dispatch_pending_memoir_phase2_for_user

            await asyncio.to_thread(dispatch_pending_memoir_phase2_for_user, user_id)
        except Exception as e:
            logger.error(
                "flush Phase2 失败: user_id={} exc_type={} exc={}",
                user_id,
                type(e).__name__,
                e,
            )

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Dispatch ``process_memoir_phase1`` and register it with the tracker.

        Args:
            user_id: Owner of the segments.
            segment_ids: Segment ids to process in this task.

        Returns:
            The Celery task id, or None if anything in the submission path
            raised.
        """
        try:
            from app.tasks.memoir_tasks import process_memoir_phase1

            result = process_memoir_phase1.delay(user_id, segment_ids)
            task_id = result.id
            # NOTE(review): if the tracker call below raises, the Celery task
            # has already been dispatched but this method still reports None.
            await task_tracker.add_task(user_id, task_id, "memoir")
            logger.info(
                "已提交 Celery 任务: user_id={}, task_id={}, segments={}",
                user_id,
                task_id,
                len(segment_ids),
            )
            return task_id
        except Exception as e:
            logger.error("提交 Celery 任务失败: {}", e)
            return None

    def _seconds_until_submit(self, batch: "_MemoirBatchState") -> float:
        """Return 0.0 if ``batch`` should be submitted now, else the next sleep.

        The gate settings are re-read from ``memoir`` on every call. A return
        value <= 0 means "submit now".
        """
        min_chars = int(memoir.segment_batch_min_chars)
        max_wait = float(memoir.segment_batch_max_wait_seconds)
        first_queued = batch.first_queued_monotonic
        if min_chars <= 0 or first_queued is None:
            # Gate disabled, or no start time to measure the wait against.
            return 0.0
        elapsed = time.monotonic() - first_queued
        if _batch_ready_for_submit(
            min_chars=min_chars,
            max_wait_seconds=max_wait,
            total_text_chars=batch.total_text_chars,
            elapsed_seconds=elapsed,
        ):
            return 0.0
        return _next_retry_sleep_seconds(float(self.debounce_seconds), max_wait, elapsed)

    async def _debounced_submit(self, user_id: str) -> None:
        """Body of a debounce timer: wait, then submit the user's batch.

        While waiting, the task is registered in ``_timers`` and may be
        cancelled by a newer ``queue_message`` or by ``flush_pending``. Once
        the batch is ready the task unregisters itself *before* popping the
        batch, so an already-dispatched submission is never cancelled halfway
        through ``_submit_task``.
        """
        me = asyncio.current_task()
        try:
            await asyncio.sleep(self.debounce_seconds)
            while True:
                batch = self._batch.get(user_id)
                if batch is None or not batch.segment_ids:
                    return
                wait_more = self._seconds_until_submit(batch)
                if wait_more <= 0:
                    break
                await asyncio.sleep(wait_more)

            # Point of no return. There is no await between unregistering and
            # popping, so no other coroutine can observe a half-taken batch.
            self._forget_timer(user_id, me)
            if me is not None:
                self._inflight.add(me)
            segment_ids = self._pop_batch(user_id)
            if segment_ids:
                await self._submit_task(user_id, segment_ids)
        except asyncio.CancelledError:
            # Cancelled while still waiting: the batch stays queued for the
            # timer that replaced this one. Propagate so the task ends in the
            # cancelled state, as asyncio expects.
            raise
        except Exception as e:
            logger.error("延迟提交任务失败: {}", e)
        finally:
            # Never leave a finished task behind in either registry.
            self._inflight.discard(me)
            self._forget_timer(user_id, me)

    async def queue_message(
        self, user_id: str, segment_id: str, *, text_char_count: int = 0
    ) -> None:
        """Add a segment to the user's batch and restart the debounce timer.

        Args:
            user_id: Owner of the segment.
            segment_id: Id of the segment to queue.
            text_char_count: Character count of the segment; negative values
                are counted as 0.
        """
        state = self._batch.get(user_id)
        if state is None:
            state = self._batch[user_id] = _MemoirBatchState()
        if not state.segment_ids:
            # First segment of a new batch: start the max-wait clock.
            state.first_queued_monotonic = time.monotonic()
        state.segment_ids.append(segment_id)
        state.total_text_chars += max(0, text_char_count)

        self._cancel_timer(user_id)
        self._timers[user_id] = asyncio.create_task(self._debounced_submit(user_id))

    def _dedupe_preserve_order(self, ids: Sequence[str]) -> list[str]:
        """Return ``ids`` without duplicates, keeping first-occurrence order."""
        return list(dict.fromkeys(ids))

    async def flush_pending(
        self,
        user_id: str,
        *,
        extra_segment_ids: Sequence[str] | None = None,
    ) -> str | None:
        """Submit the user's queued segments now, then trigger Phase 2.

        Args:
            user_id: User whose batch should be flushed.
            extra_segment_ids: Additional segment ids to include; duplicates
                of already queued ids are dropped.

        Returns:
            The Phase 1 Celery task id, or None when there was nothing to
            submit or the submission failed.
        """
        self._cancel_timer(user_id)
        queued = self._pop_batch(user_id)
        merged = self._dedupe_preserve_order([*queued, *(extra_segment_ids or ())])
        task_id: str | None = None
        if merged:
            task_id = await self._submit_task(user_id, merged)
        await self._flush_pending_phase2(user_id)
        return task_id
|