"""MemoirOrchestrator: drives the memoir pipeline segment by segment.

Responsibilities: iterate over segments, group them by chapter category,
call the specialist agents and update state. Persistence and chapter
generation are done by the ``process_category`` callback.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Set, Tuple

from app.agents.memoir.classification_agent import (
    ClassificationAgent,
)
from app.agents.memoir.classification_agent import (
    _detect_stage as detect_stage_from_keywords,
)
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
from app.agents.state_schema import MemoirStateSchema
from app.core.logging import get_logger
from app.features.conversation.models import Segment

logger = get_logger(__name__)

# Stage used whenever the state does not carry a current stage yet.
_DEFAULT_STAGE = "childhood"


@dataclass
class PreparedMemoirBatches:
    """Explicit batching result: updated state + segments grouped by chapter category."""

    state: MemoirStateSchema
    category_to_segments: Dict[str, List[Segment]]


class MemoirOrchestrator:
    """Orchestrates memoir generation.

    Flow: iterate segments -> ExtractionAgent -> ClassificationAgent ->
    group by category -> call ``process_category`` to generate the narrative
    and persist it.
    """

    def __init__(self) -> None:
        self.extraction_agent = ExtractionAgent()
        self.classification_agent = ClassificationAgent()

    def prepare_batches(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
    ) -> PreparedMemoirBatches:
        """Run extraction and classification, then bucket segments by category.

        For each segment: extract slots, push every extracted slot through
        ``update_slot``, classify the segment, and append it to its category
        bucket. No locking and no chapter/story writes happen here; the
        caller performs those explicitly.

        Args:
            segments: Segments to process, in order.
            llm: Opaque LLM handle forwarded to both agents.
            get_or_create_state: Returns the state to start from.
            update_slot: Called as ``update_slot(stage, slot_name, snippet,
                [segment.id])``; its return value becomes the current state.

        Returns:
            The latest state plus a mapping of chapter category to the
            segments classified into it. Segments for which the classifier
            returns ``None`` are left out of the mapping.
        """
        state = get_or_create_state()
        category_to_segments: Dict[str, List[Segment]] = {}

        for segment in segments:
            text = segment.transcript_text or ""
            current_stage = state.current_stage or _DEFAULT_STAGE

            # NOTE(review): the slots are looked up under the keyword-detected
            # stage, while extract() receives the state's current stage --
            # confirm this asymmetry is intentional.
            initial_stage = detect_stage_from_keywords(text, current_stage)
            stage_slots_raw = state.slots.get(initial_stage, {}) or {}

            result: ExtractionResult = self.extraction_agent.extract(
                user_message=text,
                current_stage=current_stage,
                stage_slots=stage_slots_raw,
                llm=llm,
            )
            detected_stage = result.detected_stage

            # update_slot returns the refreshed state, so later segments see
            # the slots written for earlier ones.
            for slot_name, snippet in result.slots.items():
                state = update_slot(detected_stage, slot_name, snippet, [segment.id])

            chapter_category = self.classification_agent.classify(
                text=text,
                fallback_stage=detected_stage,
                llm=llm,
            )
            if chapter_category is None:
                # The classifier found no category: skip the segment.
                logger.debug(
                    "段落无回忆录价值,跳过: segment_id=%s transcript=%s",
                    segment.id,
                    text,
                )
                continue

            category_to_segments.setdefault(chapter_category, []).append(segment)

        return PreparedMemoirBatches(
            state=state,
            category_to_segments=category_to_segments,
        )

    def run(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        user_profile: str = "",
        user_birth_year: Any = None,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
        acquire_lock: Callable[[str], bool],
        release_lock: Callable[[str], None],
        process_category: Callable[
            [
                str,
                List[Segment],
                MemoirStateSchema,
                str,
                Any,
                Any,
            ],
            Tuple[Any, bool],
        ],
        raise_retry: Callable[[], None],
    ) -> Tuple[Set[str], int]:
        """Execute the memoir pipeline.

        Batches the segments with ``prepare_batches`` and then, for each
        category, takes the category lock, calls ``process_category`` and
        releases the lock again.

        Args:
            segments: Segments to process.
            llm: Opaque LLM handle forwarded to the agents and to
                ``process_category``.
            user_profile: Forwarded unchanged to ``process_category``.
            user_birth_year: Forwarded unchanged to ``process_category``.
            get_or_create_state: See ``prepare_batches``.
            update_slot: See ``prepare_batches``.
            acquire_lock: Called with the category; a falsy result means the
                lock is contended.
            release_lock: Called with the category once processing ends,
                whether it succeeded or raised.
            process_category: Called as ``process_category(category,
                segments, state, user_profile, user_birth_year, llm)`` and
                returns ``(chapter, has_images_to_generate)``.
            raise_retry: Called on lock contention; expected to raise (a
                Celery retry, per the original contract).

        Returns:
            ``(chapters_to_enqueue, processed_count)``: the ids of chapters
            for which ``process_category`` reported images to generate, and
            ``len(segments)`` (skipped segments are counted too).

        Raises:
            RuntimeError: If the lock is contended and ``raise_retry``
                returns instead of raising.
        """
        prepared = self.prepare_batches(
            segments=segments,
            llm=llm,
            get_or_create_state=get_or_create_state,
            update_slot=update_slot,
        )
        state = prepared.state
        chapters_to_enqueue: Set[str] = set()

        # Per category: narrative generation, persistence, cover-enqueue flag.
        for chapter_category, category_segments in prepared.category_to_segments.items():
            if not acquire_lock(chapter_category):
                logger.warning(
                    "章节锁竞争: category=%s, 延迟重试",
                    chapter_category,
                )
                raise_retry()
                # Fail closed: if raise_retry returned, we must neither
                # process the category unlocked nor release a lock that
                # another worker holds.
                raise RuntimeError(
                    "raise_retry returned without raising; refusing to process "
                    f"category {chapter_category!r} without holding its lock"
                )

            try:
                chapter, has_images = process_category(
                    chapter_category,
                    category_segments,
                    state,
                    user_profile,
                    user_birth_year,
                    llm,
                )
                if chapter and has_images:
                    chapters_to_enqueue.add(chapter.id)
            finally:
                release_lock(chapter_category)

        return chapters_to_enqueue, len(segments)