配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM)
291 lines
8.2 KiB
Python
291 lines
8.2 KiB
Python
"""
|
||
回忆录流水线细粒度进度:Redis JSON 快照,以 memoir_correlation_id 为聚合根。
|
||
供 Celery worker(同步 Redis)与 internal eval API 读取。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from datetime import datetime, timezone
|
||
from typing import Any
|
||
|
||
import redis
|
||
|
||
from app.core.logging import get_logger
|
||
from app.core.redis_sync import get_sync_redis
|
||
from app.features.memoir.constants import memoir
|
||
|
||
# Module logger from the project's get_logger (app.core.logging). The warning calls in
# this module pass "{}" placeholders plus positional args, so the returned logger is
# presumably loguru-style rather than stdlib %-formatting — NOTE(review): confirm.
logger = get_logger(__name__)
|
||
|
||
|
||
def _redis() -> redis.Redis:
    """Return the synchronous Redis client from ``get_sync_redis``.

    ``decode_responses=True`` is requested, so values read back are ``str``
    rather than ``bytes``.
    """
    client = get_sync_redis(decode_responses=True)
    return client
|
||
|
||
|
||
def _run_key(correlation_id: str) -> str:
|
||
return f"memoir_pipeline_run:{correlation_id}"
|
||
|
||
|
||
def _phase1_index_key(phase1_task_id: str) -> str:
|
||
return f"memoir_pipeline_run:by_phase1_task:{phase1_task_id}"
|
||
|
||
|
||
def _ttl() -> int:
    """TTL in seconds for the snapshot and index keys.

    Read from ``memoir.pipeline_run_ttl_seconds`` and coerced to ``int``.
    """
    configured_seconds = memoir.pipeline_run_ttl_seconds
    return int(configured_seconds)
|
||
|
||
|
||
def _empty_fanout() -> dict[str, Any]:
|
||
return {
|
||
"story_images": [],
|
||
"recompose_chapters": [],
|
||
"memory_enrichment": [],
|
||
"quality_pass": None,
|
||
"compaction": None,
|
||
}
|
||
|
||
|
||
def _default_doc(correlation_id: str) -> dict[str, Any]:
    """Build the minimal snapshot used when none exists for *correlation_id*.

    ``user_id`` and ``phase1`` are unset (``None``), ``phase2`` is empty and
    ``fanout`` is a fresh empty section. ``started_at_utc`` is the current UTC
    time in ISO-8601 with a trailing ``Z``.
    """
    started = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    document: dict[str, Any] = dict(
        memoir_correlation_id=correlation_id,
        user_id=None,
        started_at_utc=started,
        phase1=None,
        phase2=[],
        fanout=_empty_fanout(),
    )
    return document
|
||
|
||
|
||
def _merge_phase2_list(
|
||
existing: list[dict[str, Any]], updates: list[dict[str, Any]]
|
||
) -> list[dict[str, Any]]:
|
||
by_tid: dict[str, dict[str, Any]] = {}
|
||
for x in existing:
|
||
tid = str(x.get("task_id") or "").strip()
|
||
if tid:
|
||
by_tid[tid] = dict(x)
|
||
for u in updates:
|
||
tid = str(u.get("task_id") or "").strip()
|
||
if not tid:
|
||
continue
|
||
if tid in by_tid:
|
||
merged = {**by_tid[tid], **u}
|
||
by_tid[tid] = merged
|
||
else:
|
||
by_tid[tid] = dict(u)
|
||
return list(by_tid.values())
|
||
|
||
|
||
def _fanout_list_merge_key(
|
||
items: list[dict], patch_items: list[dict], id_key: str
|
||
) -> None:
|
||
by_id: dict[str, dict[str, Any]] = {}
|
||
for x in items:
|
||
k = str(x.get(id_key) or "").strip()
|
||
if k:
|
||
by_id[k] = dict(x)
|
||
for u in patch_items:
|
||
k = str(u.get(id_key) or "").strip()
|
||
if not k:
|
||
continue
|
||
if k in by_id:
|
||
by_id[k] = {**by_id[k], **u}
|
||
else:
|
||
by_id[k] = dict(u)
|
||
items.clear()
|
||
items.extend(by_id.values())
|
||
|
||
|
||
def _merge_fanout(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
|
||
out = dict(base)
|
||
for k, v in patch.items():
|
||
if k in (
|
||
"story_images",
|
||
"recompose_chapters",
|
||
"memory_enrichment",
|
||
) and isinstance(v, list):
|
||
id_key = (
|
||
"story_id"
|
||
if k == "story_images"
|
||
else "chapter_id"
|
||
if k == "recompose_chapters"
|
||
else "source_id"
|
||
)
|
||
existing = list(out.get(k) or [])
|
||
_fanout_list_merge_key(existing, v, id_key)
|
||
out[k] = existing
|
||
elif k == "quality_pass" and isinstance(v, dict):
|
||
out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v)
|
||
elif k == "compaction" and isinstance(v, dict):
|
||
out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v)
|
||
else:
|
||
out[k] = v
|
||
return out
|
||
|
||
|
||
def _merge_doc(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
|
||
out = dict(base)
|
||
for k, v in patch.items():
|
||
if k == "phase2" and isinstance(v, list):
|
||
out["phase2"] = _merge_phase2_list(list(out.get("phase2") or []), v)
|
||
elif k == "fanout" and isinstance(v, dict):
|
||
out["fanout"] = _merge_fanout(dict(out.get("fanout") or _empty_fanout()), v)
|
||
elif k == "phase1" and isinstance(v, dict):
|
||
cur = dict(out.get("phase1") or {})
|
||
for pk, pv in v.items():
|
||
if (
|
||
pk == "detail"
|
||
and isinstance(pv, dict)
|
||
and isinstance(cur.get("detail"), dict)
|
||
):
|
||
cur["detail"] = {**cur["detail"], **pv}
|
||
else:
|
||
cur[pk] = pv
|
||
out["phase1"] = cur
|
||
elif isinstance(v, dict) and isinstance(out.get(k), dict):
|
||
out[k] = {**out[k], **v}
|
||
else:
|
||
out[k] = v
|
||
return out
|
||
|
||
|
||
def merge_pipeline_run(correlation_id: str, patch: dict[str, Any]) -> None:
    """Merge *patch* into the pipeline snapshot (creating a minimal doc if absent).

    The current JSON document is read, merged with ``_merge_doc`` and written
    back with a refreshed TTL. If the key is missing, ``_default_doc`` is used
    as the starting point.

    The read-modify-write runs under Redis ``WATCH``/``MULTI``/``EXEC``. If
    another writer changes the key between our read and our write, ``EXEC``
    aborts (``WatchError``) and the merge is retried against the fresh value.
    Without this, concurrent updates for the same correlation id (presumably
    parallel Celery workers) could silently overwrite each other. Retries are
    bounded.

    Best-effort: failures (including exhausted retries) are logged as warnings
    and never raised to the caller.
    """
    cid = (correlation_id or "").strip()
    if not cid:
        return
    max_attempts = 10
    try:
        r = _redis()
        key = _run_key(cid)
        ttl = _ttl()
        for _ in range(max_attempts):
            with r.pipeline() as pipe:
                try:
                    # After WATCH the pipeline executes commands immediately until multi().
                    pipe.watch(key)
                    raw = pipe.get(key)
                    doc = json.loads(raw) if raw else _default_doc(cid)
                    doc = _merge_doc(doc, patch)
                    pipe.multi()
                    pipe.setex(key, ttl, json.dumps(doc, ensure_ascii=False))
                    pipe.execute()
                    return
                except redis.WatchError:
                    # Key changed after WATCH: retry the whole merge on fresh data.
                    continue
        logger.warning(
            "memoir_pipeline_progress merge gave up after {} attempts correlation_id={}",
            max_attempts,
            cid,
        )
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress merge failed correlation_id={} err={}",
            cid,
            e,
        )
|
||
|
||
|
||
def init_pipeline_run_from_phase1(
    user_id: str,
    correlation_id: str,
    phase1_task_id: str,
    *,
    segment_count: int,
) -> None:
    """Create (or overwrite) the pipeline snapshot when phase 1 starts.

    Writes a fresh document for *correlation_id*: ``phase1`` is marked
    ``running``/``started`` with ``detail.segment_count``, ``phase2`` is empty
    and ``fanout`` is empty. It also writes the reverse index
    *phase1_task_id* -> *correlation_id*. Both keys get the same TTL.

    The two writes are sent in one ``MULTI``/``EXEC`` pipeline. They are
    applied atomically, so a failure can no longer leave a snapshot without
    its index. Does nothing if any of the three ids is blank.

    Best-effort: errors are logged as warnings and never raised.
    """
    cid = (correlation_id or "").strip()
    uid = (user_id or "").strip()
    tid = (phase1_task_id or "").strip()
    if not cid or not uid or not tid:
        return
    try:
        r = _redis()
        now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
        doc = {
            "memoir_correlation_id": cid,
            "user_id": uid,
            "started_at_utc": now,
            "phase1": {
                "task_id": tid,
                "status": "running",
                "step": "started",
                "detail": {"segment_count": int(segment_count)},
            },
            "phase2": [],
            "fanout": _empty_fanout(),
        }
        ttl = _ttl()
        payload = json.dumps(doc, ensure_ascii=False)
        # Snapshot + reverse index in a single transaction (one round trip, all-or-nothing).
        with r.pipeline(transaction=True) as pipe:
            pipe.setex(_run_key(cid), ttl, payload)
            pipe.setex(_phase1_index_key(tid), ttl, cid)
            pipe.execute()
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress init failed correlation_id={} err={}",
            cid,
            e,
        )
|
||
|
||
|
||
def get_pipeline_run_snapshot(correlation_id: str) -> dict[str, Any] | None:
    """Return the decoded pipeline snapshot for *correlation_id*, or ``None``.

    ``None`` is returned for a blank id, a missing/empty key, or any Redis/JSON
    error (which is logged as a warning).
    """
    cid = (correlation_id or "").strip()
    if not cid:
        return None
    try:
        payload = _redis().get(_run_key(cid))
        return json.loads(payload) if payload else None
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress get failed correlation_id={} err={}",
            cid,
            e,
        )
    return None
|
||
|
||
|
||
def resolve_correlation_id_for_phase1_task(phase1_task_id: str) -> str | None:
    """Look up the correlation id indexed under *phase1_task_id*.

    Returns the stripped stored value, or ``None`` when the task id is blank,
    the index key is missing/blank, or Redis raises (logged as a warning).
    """
    tid = (phase1_task_id or "").strip()
    if not tid:
        return None
    try:
        stored = _redis().get(_phase1_index_key(tid))
        cleaned = (stored or "").strip()
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress resolve phase1_task={} err={}",
            tid,
            e,
        )
        return None
    return cleaned or None
|
||
|
||
|
||
def get_pipeline_run_for_eval(
    user_id: str,
    *,
    memoir_correlation_id: str | None = None,
    phase1_task_id: str | None = None,
) -> dict[str, Any] | None:
    """Internal eval: return the snapshot only if it belongs to *user_id*.

    The correlation id is taken from *memoir_correlation_id*. If that is blank
    and *phase1_task_id* is given, it is resolved through the phase-1 index.
    Returns ``None`` when *user_id* is blank, no correlation id can be found,
    no snapshot exists, or the snapshot's ``user_id`` does not match.
    """
    uid = (user_id or "").strip()
    if not uid:
        return None
    cid = (memoir_correlation_id or "").strip()
    if not cid and phase1_task_id:
        cid = resolve_correlation_id_for_phase1_task(phase1_task_id) or ""
    snap = get_pipeline_run_snapshot(cid) if cid else None
    if not snap:
        return None
    owner = str(snap.get("user_id") or "").strip()
    return snap if owner == uid else None
|
||
|
||
|
||
def merge_fanout_item(
    correlation_id: str | None,
    *,
    list_name: str,
    id_field: str,
    item_id: str,
    task_id: str,
    status: str,
    extra: dict[str, Any] | None = None,
) -> None:
    """Upsert one fan-out item into ``fanout[list_name]`` of the snapshot.

    The item is ``{id_field: item_id, "task_id": ..., "status": ...}``, updated
    with *extra* when given (so *extra* can override those fields). It is sent
    through ``merge_pipeline_run``. Does nothing for a blank correlation id.

    NOTE(review): ``_merge_fanout`` chooses the identity key from *list_name*
    itself (``story_id`` / ``chapter_id`` / ``source_id``), not from
    *id_field*. If the two disagree, the item has no usable key and is dropped
    during the merge.
    """
    cid = (correlation_id or "").strip()
    if not cid:
        return
    entry: dict[str, Any] = {id_field: item_id}
    entry["task_id"] = task_id
    entry["status"] = status
    if extra:
        entry.update(extra)
    fanout_patch = {"fanout": {list_name: [entry]}}
    merge_pipeline_run(cid, fanout_patch)
|