Files
life-echo/api/app/core/memoir_pipeline_progress.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

291 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
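"""
回忆录流水线细粒度进度(Redis JSON 快照),以 memoir_correlation_id 为聚合根。
供 Celery worker(同步 Redis)与 internal eval API 读取。

Fine-grained progress tracking for the memoir pipeline, stored as a Redis JSON
snapshot whose aggregate root is ``memoir_correlation_id``. Used by Celery
workers (through the synchronous Redis client) and read by the internal eval API.
"""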
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
import redis
from app.core.logging import get_logger
from app.core.redis_sync import get_sync_redis
from app.features.memoir.constants import memoir
# Module logger; every Redis/JSON failure in this module is reported through it
# as a warning and otherwise swallowed (progress tracking is best-effort).
logger = get_logger(__name__)
def _redis() -> redis.Redis:
    """Return the synchronous Redis client used by this module.

    ``decode_responses=True`` is requested so values are returned as ``str``
    instead of ``bytes`` (e.g. the correlation id read from the phase-1 index).
    """
    client = get_sync_redis(decode_responses=True)
    return client
def _run_key(correlation_id: str) -> str:
return f"memoir_pipeline_run:{correlation_id}"
def _phase1_index_key(phase1_task_id: str) -> str:
return f"memoir_pipeline_run:by_phase1_task:{phase1_task_id}"
def _ttl() -> int:
    """Expiry in seconds applied to every key this module writes (via SETEX)."""
    configured = memoir.pipeline_run_ttl_seconds
    return int(configured)
def _empty_fanout() -> dict[str, Any]:
return {
"story_images": [],
"recompose_chapters": [],
"memory_enrichment": [],
"quality_pass": None,
"compaction": None,
}
def _default_doc(correlation_id: str) -> dict[str, Any]:
    """Minimal snapshot used when a merge finds no stored document.

    ``user_id`` and ``phase1`` start as ``None`` and ``phase2`` as an empty
    list. ``started_at_utc`` is the current UTC time in ISO-8601, with
    ``+00:00`` rewritten as ``Z``.
    """
    started_at = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return dict(
        memoir_correlation_id=correlation_id,
        user_id=None,
        started_at_utc=started_at,
        phase1=None,
        phase2=[],
        fanout=_empty_fanout(),
    )
def _merge_phase2_list(
existing: list[dict[str, Any]], updates: list[dict[str, Any]]
) -> list[dict[str, Any]]:
by_tid: dict[str, dict[str, Any]] = {}
for x in existing:
tid = str(x.get("task_id") or "").strip()
if tid:
by_tid[tid] = dict(x)
for u in updates:
tid = str(u.get("task_id") or "").strip()
if not tid:
continue
if tid in by_tid:
merged = {**by_tid[tid], **u}
by_tid[tid] = merged
else:
by_tid[tid] = dict(u)
return list(by_tid.values())
def _fanout_list_merge_key(
items: list[dict], patch_items: list[dict], id_key: str
) -> None:
by_id: dict[str, dict[str, Any]] = {}
for x in items:
k = str(x.get(id_key) or "").strip()
if k:
by_id[k] = dict(x)
for u in patch_items:
k = str(u.get(id_key) or "").strip()
if not k:
continue
if k in by_id:
by_id[k] = {**by_id[k], **u}
else:
by_id[k] = dict(u)
items.clear()
items.extend(by_id.values())
def _merge_fanout(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
out = dict(base)
for k, v in patch.items():
if k in (
"story_images",
"recompose_chapters",
"memory_enrichment",
) and isinstance(v, list):
id_key = (
"story_id"
if k == "story_images"
else "chapter_id"
if k == "recompose_chapters"
else "source_id"
)
existing = list(out.get(k) or [])
_fanout_list_merge_key(existing, v, id_key)
out[k] = existing
elif k == "quality_pass" and isinstance(v, dict):
out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v)
elif k == "compaction" and isinstance(v, dict):
out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v)
else:
out[k] = v
return out
def _merge_doc(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
out = dict(base)
for k, v in patch.items():
if k == "phase2" and isinstance(v, list):
out["phase2"] = _merge_phase2_list(list(out.get("phase2") or []), v)
elif k == "fanout" and isinstance(v, dict):
out["fanout"] = _merge_fanout(dict(out.get("fanout") or _empty_fanout()), v)
elif k == "phase1" and isinstance(v, dict):
cur = dict(out.get("phase1") or {})
for pk, pv in v.items():
if (
pk == "detail"
and isinstance(pv, dict)
and isinstance(cur.get("detail"), dict)
):
cur["detail"] = {**cur["detail"], **pv}
else:
cur[pk] = pv
out["phase1"] = cur
elif isinstance(v, dict) and isinstance(out.get(k), dict):
out[k] = {**out[k], **v}
else:
out[k] = v
return out
def merge_pipeline_run(correlation_id: str, patch: dict[str, Any]) -> None:
    """Merge ``patch`` into the run snapshot, creating a minimal doc if absent.

    Best-effort: a blank ``correlation_id`` is ignored, and any Redis/JSON/merge
    failure is logged as a warning and swallowed. Each successful write
    refreshes the key's TTL.

    The read-merge-write cycle runs under Redis optimistic locking
    (``WATCH`` / ``MULTI`` / ``EXEC``). Callers may report progress for the
    same run concurrently, for example one ``merge_fanout_item`` call per
    fan-out item from different workers. With a plain ``GET`` followed by
    ``SETEX``, a slower writer would overwrite, and so silently drop, another
    writer's update. On a conflict the whole cycle is retried against the
    fresh value, up to a small fixed number of attempts.

    Args:
        correlation_id: Run identifier (the snapshot's aggregate root).
        patch: Partial snapshot, merged according to ``_merge_doc``.
    """
    cid = (correlation_id or "").strip()
    if not cid:
        return
    max_attempts = 5
    try:
        r = _redis()
        key = _run_key(cid)
        ttl = _ttl()
        with r.pipeline() as pipe:
            for _attempt in range(max_attempts):
                try:
                    # WATCH puts the pipeline in immediate mode, so GET
                    # returns the current value straight away.
                    pipe.watch(key)
                    raw = pipe.get(key)
                    doc = json.loads(raw) if raw else _default_doc(cid)
                    doc = _merge_doc(doc, patch)
                    pipe.multi()
                    pipe.setex(key, ttl, json.dumps(doc, ensure_ascii=False))
                    # EXEC raises WatchError if the key changed after WATCH.
                    pipe.execute()
                    return
                except redis.WatchError:
                    continue
        logger.warning(
            "memoir_pipeline_progress merge gave up after {} conflicting attempts correlation_id={}",
            max_attempts,
            cid,
        )
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress merge failed correlation_id={} err={}",
            cid,
            e,
        )
def init_pipeline_run_from_phase1(
    user_id: str,
    correlation_id: str,
    phase1_task_id: str,
    *,
    segment_count: int,
) -> None:
    """Create (or overwrite) the run snapshot when phase 1 starts.

    The new snapshot is owned by ``user_id``, with phase 1 marked
    ``running`` / ``started`` and ``detail.segment_count`` set. A second key
    indexes the phase-1 task id to the correlation id; it is read by
    ``resolve_correlation_id_for_phase1_task``. Both keys get the same TTL.

    Does nothing if any identifier is blank after stripping. Failures
    (including a non-integer ``segment_count``) are logged as a warning and
    swallowed.
    """
    cid = (correlation_id or "").strip()
    uid = (user_id or "").strip()
    tid = (phase1_task_id or "").strip()
    if not cid or not uid or not tid:
        return
    try:
        doc = _default_doc(cid)
        doc["user_id"] = uid
        doc["phase1"] = {
            "task_id": tid,
            "status": "running",
            "step": "started",
            "detail": {"segment_count": int(segment_count)},
        }
        ttl = _ttl()
        # MULTI/EXEC: write the snapshot and its index together. A failure
        # between two independent writes would otherwise leave an index entry
        # without a snapshot, or a snapshot unreachable by phase-1 task id.
        with _redis().pipeline(transaction=True) as pipe:
            pipe.setex(_run_key(cid), ttl, json.dumps(doc, ensure_ascii=False))
            pipe.setex(_phase1_index_key(tid), ttl, cid)
            pipe.execute()
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress init failed correlation_id={} err={}",
            cid,
            e,
        )
def get_pipeline_run_snapshot(correlation_id: str) -> dict[str, Any] | None:
cid = (correlation_id or "").strip()
if not cid:
return None
try:
raw = _redis().get(_run_key(cid))
if not raw:
return None
return json.loads(raw)
except Exception as e:
logger.warning(
"memoir_pipeline_progress get failed correlation_id={} err={}",
cid,
e,
)
return None
def resolve_correlation_id_for_phase1_task(phase1_task_id: str) -> str | None:
tid = (phase1_task_id or "").strip()
if not tid:
return None
try:
cid = _redis().get(_phase1_index_key(tid))
return (cid or "").strip() or None
except Exception as e:
logger.warning(
"memoir_pipeline_progress resolve phase1_task={} err={}",
tid,
e,
)
return None
def get_pipeline_run_for_eval(
user_id: str,
*,
memoir_correlation_id: str | None = None,
phase1_task_id: str | None = None,
) -> dict[str, Any] | None:
"""Internal eval校验 user_id 与快照一致后返回。"""
uid = (user_id or "").strip()
if not uid:
return None
cid = (memoir_correlation_id or "").strip()
if not cid and phase1_task_id:
cid = resolve_correlation_id_for_phase1_task(phase1_task_id) or ""
if not cid:
return None
snap = get_pipeline_run_snapshot(cid)
if not snap:
return None
if str(snap.get("user_id") or "").strip() != uid:
return None
return snap
def merge_fanout_item(
correlation_id: str | None,
*,
list_name: str,
id_field: str,
item_id: str,
task_id: str,
status: str,
extra: dict[str, Any] | None = None,
) -> None:
cid = (correlation_id or "").strip()
if not cid:
return
item: dict[str, Any] = {
id_field: item_id,
"task_id": task_id,
"status": status,
}
if extra:
item.update(extra)
merge_pipeline_run(cid, {"fanout": {list_name: [item]}})