Files
life-echo/api/app/core/memoir_pipeline_progress.py

291 lines
8.2 KiB
Python
Raw Normal View History

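"""
Fine-grained memoir pipeline progress: a Redis JSON snapshot keyed by
memoir_correlation_id (the aggregate root).

Celery workers update the snapshot through the synchronous Redis client;
the internal eval API reads it.
"""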
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
import redis
from app.core.logging import get_logger
from app.core.redis_sync import get_sync_redis
from app.features.memoir.constants import memoir
logger = get_logger(__name__)
def _redis() -> redis.Redis:
    """Return the synchronous Redis client, requested with ``decode_responses=True``."""
    client = get_sync_redis(decode_responses=True)
    return client
def _run_key(correlation_id: str) -> str:
return f"memoir_pipeline_run:{correlation_id}"
def _phase1_index_key(phase1_task_id: str) -> str:
return f"memoir_pipeline_run:by_phase1_task:{phase1_task_id}"
def _ttl() -> int:
    """Return ``memoir.pipeline_run_ttl_seconds`` coerced to ``int``.

    The value is passed as the expiry argument of every ``setex`` call in
    this module.
    """
    configured = memoir.pipeline_run_ttl_seconds
    return int(configured)
def _empty_fanout() -> dict[str, Any]:
return {
"story_images": [],
"recompose_chapters": [],
"memory_enrichment": [],
"quality_pass": None,
"compaction": None,
}
def _default_doc(correlation_id: str) -> dict[str, Any]:
    """Return the minimal snapshot document for *correlation_id*.

    ``started_at_utc`` is the current UTC time in ISO-8601 form with the
    ``+00:00`` offset rewritten as ``Z``. No user or phase-1 data is set,
    and both ``phase2`` and ``fanout`` start empty.
    """
    started_at = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    skeleton: dict[str, Any] = {
        "memoir_correlation_id": correlation_id,
        "user_id": None,
    }
    skeleton["started_at_utc"] = started_at
    skeleton["phase1"] = None
    skeleton["phase2"] = []
    skeleton["fanout"] = _empty_fanout()
    return skeleton
def _merge_phase2_list(
existing: list[dict[str, Any]], updates: list[dict[str, Any]]
) -> list[dict[str, Any]]:
by_tid: dict[str, dict[str, Any]] = {}
for x in existing:
tid = str(x.get("task_id") or "").strip()
if tid:
by_tid[tid] = dict(x)
for u in updates:
tid = str(u.get("task_id") or "").strip()
if not tid:
continue
if tid in by_tid:
merged = {**by_tid[tid], **u}
by_tid[tid] = merged
else:
by_tid[tid] = dict(u)
return list(by_tid.values())
def _fanout_list_merge_key(
items: list[dict], patch_items: list[dict], id_key: str
) -> None:
by_id: dict[str, dict[str, Any]] = {}
for x in items:
k = str(x.get(id_key) or "").strip()
if k:
by_id[k] = dict(x)
for u in patch_items:
k = str(u.get(id_key) or "").strip()
if not k:
continue
if k in by_id:
by_id[k] = {**by_id[k], **u}
else:
by_id[k] = dict(u)
items.clear()
items.extend(by_id.values())
def _merge_fanout(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
out = dict(base)
for k, v in patch.items():
if k in (
"story_images",
"recompose_chapters",
"memory_enrichment",
) and isinstance(v, list):
id_key = (
"story_id"
if k == "story_images"
else "chapter_id"
if k == "recompose_chapters"
else "source_id"
)
existing = list(out.get(k) or [])
_fanout_list_merge_key(existing, v, id_key)
out[k] = existing
elif k == "quality_pass" and isinstance(v, dict):
out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v)
elif k == "compaction" and isinstance(v, dict):
out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v)
else:
out[k] = v
return out
def _merge_doc(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]:
out = dict(base)
for k, v in patch.items():
if k == "phase2" and isinstance(v, list):
out["phase2"] = _merge_phase2_list(list(out.get("phase2") or []), v)
elif k == "fanout" and isinstance(v, dict):
out["fanout"] = _merge_fanout(dict(out.get("fanout") or _empty_fanout()), v)
elif k == "phase1" and isinstance(v, dict):
cur = dict(out.get("phase1") or {})
for pk, pv in v.items():
if (
pk == "detail"
and isinstance(pv, dict)
and isinstance(cur.get("detail"), dict)
):
cur["detail"] = {**cur["detail"], **pv}
else:
cur[pk] = pv
out["phase1"] = cur
elif isinstance(v, dict) and isinstance(out.get(k), dict):
out[k] = {**out[k], **v}
else:
out[k] = v
return out
def merge_pipeline_run(correlation_id: str, patch: dict[str, Any]) -> None:
    """Merge *patch* into the pipeline snapshot, creating a minimal one if absent.

    The read-merge-write cycle runs under Redis optimistic locking
    (WATCH / MULTI / EXEC). If another writer changes the key between the
    read and the write, the cycle is retried against the fresh value
    instead of silently overwriting the other writer's merge. Retries are
    bounded; giving up is logged.

    The call is best-effort: a blank ``correlation_id`` is ignored and any
    failure is logged as a warning, never raised.

    Args:
        correlation_id: Snapshot id; surrounding whitespace is stripped.
        patch: Partial document merged via ``_merge_doc``.
    """
    cid = (correlation_id or "").strip()
    if not cid:
        return

    max_attempts = 5
    try:
        key = _run_key(cid)
        ttl = _ttl()
        with _redis().pipeline() as pipe:
            for _ in range(max_attempts):
                try:
                    # After WATCH the pipeline executes commands immediately,
                    # so GET returns the current value right away.
                    pipe.watch(key)
                    raw = pipe.get(key)
                    doc = json.loads(raw) if raw else _default_doc(cid)
                    doc = _merge_doc(doc, patch)
                    # Back to buffered mode: EXEC applies SETEX only if the
                    # watched key is still unchanged.
                    pipe.multi()
                    pipe.setex(key, ttl, json.dumps(doc, ensure_ascii=False))
                    pipe.execute()
                    return
                except redis.WatchError:
                    # Another writer won the race; re-read and merge again.
                    continue
        logger.warning(
            "memoir_pipeline_progress merge gave up after {} contended attempts correlation_id={}",
            max_attempts,
            cid,
        )
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress merge failed correlation_id={} err={}",
            cid,
            e,
        )
def init_pipeline_run_from_phase1(
    user_id: str,
    correlation_id: str,
    phase1_task_id: str,
    *,
    segment_count: int,
) -> None:
    """Create the snapshot for a run whose phase 1 has just started.

    Writes a fresh document (phase 1 marked ``running`` / ``started`` with
    the segment count in ``detail``) under the run key, plus an index entry
    mapping the phase-1 task id to the correlation id. Both keys get the
    same TTL. Any stored document under the run key is overwritten.

    Nothing is written when any of the three ids is blank after stripping.
    Failures are logged as warnings and not raised.
    """
    cid = (correlation_id or "").strip()
    uid = (user_id or "").strip()
    tid = (phase1_task_id or "").strip()
    if not (cid and uid and tid):
        return

    try:
        client = _redis()
        started_at = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
        phase1_state = {
            "task_id": tid,
            "status": "running",
            "step": "started",
            "detail": {"segment_count": int(segment_count)},
        }
        payload = json.dumps(
            {
                "memoir_correlation_id": cid,
                "user_id": uid,
                "started_at_utc": started_at,
                "phase1": phase1_state,
                "phase2": [],
                "fanout": _empty_fanout(),
            },
            ensure_ascii=False,
        )
        expiry = _ttl()
        client.setex(_run_key(cid), expiry, payload)
        client.setex(_phase1_index_key(tid), expiry, cid)
    except Exception as e:
        logger.warning(
            "memoir_pipeline_progress init failed correlation_id={} err={}",
            cid,
            e,
        )
def get_pipeline_run_snapshot(correlation_id: str) -> dict[str, Any] | None:
cid = (correlation_id or "").strip()
if not cid:
return None
try:
raw = _redis().get(_run_key(cid))
if not raw:
return None
return json.loads(raw)
except Exception as e:
logger.warning(
"memoir_pipeline_progress get failed correlation_id={} err={}",
cid,
e,
)
return None
def resolve_correlation_id_for_phase1_task(phase1_task_id: str) -> str | None:
tid = (phase1_task_id or "").strip()
if not tid:
return None
try:
cid = _redis().get(_phase1_index_key(tid))
return (cid or "").strip() or None
except Exception as e:
logger.warning(
"memoir_pipeline_progress resolve phase1_task={} err={}",
tid,
e,
)
return None
def get_pipeline_run_for_eval(
user_id: str,
*,
memoir_correlation_id: str | None = None,
phase1_task_id: str | None = None,
) -> dict[str, Any] | None:
"""Internal eval校验 user_id 与快照一致后返回。"""
uid = (user_id or "").strip()
if not uid:
return None
cid = (memoir_correlation_id or "").strip()
if not cid and phase1_task_id:
cid = resolve_correlation_id_for_phase1_task(phase1_task_id) or ""
if not cid:
return None
snap = get_pipeline_run_snapshot(cid)
if not snap:
return None
if str(snap.get("user_id") or "").strip() != uid:
return None
return snap
def merge_fanout_item(
correlation_id: str | None,
*,
list_name: str,
id_field: str,
item_id: str,
task_id: str,
status: str,
extra: dict[str, Any] | None = None,
) -> None:
cid = (correlation_id or "").strip()
if not cid:
return
item: dict[str, Any] = {
id_field: item_id,
"task_id": task_id,
"status": status,
}
if extra:
item.update(extra)
merge_pipeline_run(cid, {"fanout": {list_name: [item]}})