""" 回忆录流水线细粒度进度:Redis JSON 快照,以 memoir_correlation_id 为聚合根。 供 Celery worker(同步 Redis)与 internal eval API 读取。 """ from __future__ import annotations import json from datetime import datetime, timezone from typing import Any import redis from app.core.logging import get_logger from app.core.redis_sync import get_sync_redis from app.features.memoir.constants import memoir logger = get_logger(__name__) def _redis() -> redis.Redis: return get_sync_redis(decode_responses=True) def _run_key(correlation_id: str) -> str: return f"memoir_pipeline_run:{correlation_id}" def _phase1_index_key(phase1_task_id: str) -> str: return f"memoir_pipeline_run:by_phase1_task:{phase1_task_id}" def _ttl() -> int: return int(memoir.pipeline_run_ttl_seconds) def _empty_fanout() -> dict[str, Any]: return { "story_images": [], "recompose_chapters": [], "memory_enrichment": [], "quality_pass": None, "compaction": None, } def _default_doc(correlation_id: str) -> dict[str, Any]: return { "memoir_correlation_id": correlation_id, "user_id": None, "started_at_utc": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), "phase1": None, "phase2": [], "fanout": _empty_fanout(), } def _merge_phase2_list( existing: list[dict[str, Any]], updates: list[dict[str, Any]] ) -> list[dict[str, Any]]: by_tid: dict[str, dict[str, Any]] = {} for x in existing: tid = str(x.get("task_id") or "").strip() if tid: by_tid[tid] = dict(x) for u in updates: tid = str(u.get("task_id") or "").strip() if not tid: continue if tid in by_tid: merged = {**by_tid[tid], **u} by_tid[tid] = merged else: by_tid[tid] = dict(u) return list(by_tid.values()) def _fanout_list_merge_key( items: list[dict], patch_items: list[dict], id_key: str ) -> None: by_id: dict[str, dict[str, Any]] = {} for x in items: k = str(x.get(id_key) or "").strip() if k: by_id[k] = dict(x) for u in patch_items: k = str(u.get(id_key) or "").strip() if not k: continue if k in by_id: by_id[k] = {**by_id[k], **u} else: by_id[k] = dict(u) items.clear() items.extend(by_id.values()) def _merge_fanout(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]: out = dict(base) for k, v in patch.items(): if k in ( "story_images", "recompose_chapters", "memory_enrichment", ) and isinstance(v, list): id_key = ( "story_id" if k == "story_images" else "chapter_id" if k == "recompose_chapters" else "source_id" ) existing = list(out.get(k) or []) _fanout_list_merge_key(existing, v, id_key) out[k] = existing elif k == "quality_pass" and isinstance(v, dict): out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v) elif k == "compaction" and isinstance(v, dict): out[k] = {**(out.get(k) or {}), **v} if out.get(k) else dict(v) else: out[k] = v return out def _merge_doc(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]: out = dict(base) for k, v in patch.items(): if k == "phase2" and isinstance(v, list): out["phase2"] = _merge_phase2_list(list(out.get("phase2") or []), v) elif k == "fanout" and isinstance(v, dict): out["fanout"] = _merge_fanout(dict(out.get("fanout") or _empty_fanout()), v) elif k == "phase1" and isinstance(v, dict): cur = dict(out.get("phase1") or {}) for pk, pv in v.items(): if ( pk == "detail" and isinstance(pv, dict) and isinstance(cur.get("detail"), dict) ): cur["detail"] = {**cur["detail"], **pv} else: cur[pk] = pv out["phase1"] = cur elif isinstance(v, dict) and isinstance(out.get(k), dict): out[k] = {**out[k], **v} else: out[k] = v return out def merge_pipeline_run(correlation_id: str, patch: dict[str, Any]) -> None: """合并补丁到流水线快照(不存在则创建最小文档)。""" cid = (correlation_id or "").strip() if not cid: return try: r = _redis() key = _run_key(cid) raw = r.get(key) if raw: doc = json.loads(raw) else: doc = _default_doc(cid) doc = _merge_doc(doc, patch) r.setex(key, _ttl(), json.dumps(doc, ensure_ascii=False)) except Exception as e: logger.warning( "memoir_pipeline_progress merge failed correlation_id={} err={}", cid, e, ) def init_pipeline_run_from_phase1( user_id: str, correlation_id: str, phase1_task_id: str, *, segment_count: int, ) -> None: cid = (correlation_id or "").strip() uid = (user_id or "").strip() tid = (phase1_task_id or "").strip() if not cid or not uid or not tid: return try: r = _redis() now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") doc = { "memoir_correlation_id": cid, "user_id": uid, "started_at_utc": now, "phase1": { "task_id": tid, "status": "running", "step": "started", "detail": {"segment_count": int(segment_count)}, }, "phase2": [], "fanout": _empty_fanout(), } ttl = _ttl() r.setex(_run_key(cid), ttl, json.dumps(doc, ensure_ascii=False)) r.setex(_phase1_index_key(tid), ttl, cid) except Exception as e: logger.warning( "memoir_pipeline_progress init failed correlation_id={} err={}", cid, e, ) def get_pipeline_run_snapshot(correlation_id: str) -> dict[str, Any] | None: cid = (correlation_id or "").strip() if not cid: return None try: raw = _redis().get(_run_key(cid)) if not raw: return None return json.loads(raw) except Exception as e: logger.warning( "memoir_pipeline_progress get failed correlation_id={} err={}", cid, e, ) return None def resolve_correlation_id_for_phase1_task(phase1_task_id: str) -> str | None: tid = (phase1_task_id or "").strip() if not tid: return None try: cid = _redis().get(_phase1_index_key(tid)) return (cid or "").strip() or None except Exception as e: logger.warning( "memoir_pipeline_progress resolve phase1_task={} err={}", tid, e, ) return None def get_pipeline_run_for_eval( user_id: str, *, memoir_correlation_id: str | None = None, phase1_task_id: str | None = None, ) -> dict[str, Any] | None: """Internal eval:校验 user_id 与快照一致后返回。""" uid = (user_id or "").strip() if not uid: return None cid = (memoir_correlation_id or "").strip() if not cid and phase1_task_id: cid = resolve_correlation_id_for_phase1_task(phase1_task_id) or "" if not cid: return None snap = get_pipeline_run_snapshot(cid) if not snap: return None if str(snap.get("user_id") or "").strip() != uid: return None return snap def merge_fanout_item( correlation_id: str | None, *, list_name: str, id_field: str, item_id: str, task_id: str, status: str, extra: dict[str, Any] | None = None, ) -> None: cid = (correlation_id or "").strip() if not cid: return item: dict[str, Any] = { id_field: item_id, "task_id": task_id, "status": status, } if extra: item.update(extra) merge_pipeline_run(cid, {"fanout": {list_name: [item]}})