"""
MemoirOrchestrator: drives the memoir pipeline segment by segment, calling the
Specialist agents.

Responsibilities: iterate over segments, group them by chapter category, invoke
the Specialist agents and update the memoir state. Persistence and chapter
generation are delegated to the ``process_category`` callback.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from app.agents.memoir.batch_phase1_prep import (
    STAGE_ALLOWED_SLOTS,
    run_batch_phase1_prep_chunked,
)
from app.agents.memoir.classification_agent import (
    ClassificationAgent,
    _looks_like_fragment_only,
)
from app.agents.memoir.classification_agent import (
    _detect_stage as detect_stage_from_keywords,
)
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
from app.agents.stage_constants import normalize_chapter_category, normalize_chat_stage
from app.agents.state_schema import MemoirStateSchema
from app.core.agent_logging import agent_span, agent_summary_enabled, log_agent_detail
from app.core.config import settings
from app.core.logging import get_logger
from app.features.conversation.models import Segment

logger = get_logger(__name__)


@dataclass
class PreparedMemoirBatches:
    """Explicit batching result: updated state + segments grouped by chapter category."""

    state: MemoirStateSchema
    category_to_segments: Dict[str, List[Segment]]
    #: Segment ids (as str) added when the LLM classified the segment as "none"
    #: AND extraction produced no slots; batch-level short-circuiting is
    #: handled in memoir_tasks.
    segment_skip_story_ids: Set[str]
    #: segment id (as str) -> Phase 1 chapter_category (persisted to
    #: Segment.topic_category by the caller).
    segment_chapter_category: Dict[str, str]


class MemoirOrchestrator:
    """
    Memoir generation orchestrator.

    For each segment: ExtractionAgent -> ClassificationAgent -> group by
    category; then call ``process_category`` per category to generate the
    narrative and persist it.
    """

    def __init__(self) -> None:
        self.extraction_agent = ExtractionAgent()
        self.classification_agent = ClassificationAgent()

    def prepare_batches(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
        llm_fast: Any | None = None,
        on_phase1_chunk: Optional[Callable[[int, int], None]] = None,
    ) -> PreparedMemoirBatches:
        """
        Run Phase 1 over ``segments``: Extraction -> slot update -> Classification
        -> bucket by chapter category.

        Does not take locks and does not write chapters/stories; the caller does
        that explicitly afterwards.

        Args:
            segments: Segments to process, in order.
            llm: Main LLM (narrative/routing). Used for classification and
                extraction only when ``llm_fast`` is not given.
            get_or_create_state: Returns the current memoir state.
            update_slot: ``(stage, slot_name, snippet, segment_ids) -> new state``.
            llm_fast: Optional LLM dedicated to classification and extraction.
            on_phase1_chunk: Optional progress callback forwarded to the batch
                path (two ints; semantics defined by
                ``run_batch_phase1_prep_chunked``).

        Returns:
            A :class:`PreparedMemoirBatches` with the final state and groupings.
        """
        state = get_or_create_state()
        classify_extract_llm = llm_fast if llm_fast is not None else llm

        # The batch path is the default main path (requires an LLM and the
        # feature flag); on any failure we fall back to per-segment processing.
        use_batch = (
            bool(segments)
            and classify_extract_llm is not None
            and settings.memoir_phase1_batch_llm_enabled
        )
        if use_batch:
            try:
                result = self._prepare_batches_via_batch_llm(
                    segments=segments,
                    state=state,
                    classify_extract_llm=classify_extract_llm,
                    update_slot=update_slot,
                    on_phase1_chunk=on_phase1_chunk,
                )
                logger.info(
                    "event=phase1_batch_path_used segment_count={} "
                    "msg=Phase1 批处理 LLM 路径已使用",
                    len(segments),
                )
                return result
            except Exception as e:
                # Best-effort batch path: log and fall through to the per-segment
                # path below. NOTE: if the failure happened after some
                # update_slot calls (e.g. update_slot itself raised), those
                # updates are not rolled back.
                logger.warning(
                    "event=phase1_batch_path_fallback segment_count={} exc={} "
                    "msg=Phase1 批处理失败,回退逐段",
                    len(segments),
                    e,
                )

        return self._prepare_batches_per_segment(
            segments=segments,
            state=state,
            classify_extract_llm=classify_extract_llm,
            update_slot=update_slot,
        )

    def _prepare_batches_per_segment(
        self,
        *,
        segments: List[Segment],
        state: MemoirStateSchema,
        classify_extract_llm: Any,
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
    ) -> PreparedMemoirBatches:
        """Per-segment path: one Extraction + one Classification call per segment."""
        category_to_segments: Dict[str, List[Segment]] = {}
        segment_skip_story_ids: Set[str] = set()
        segment_chapter_category: Dict[str, str] = {}

        for segment in segments:
            text = segment.user_input_text or ""
            seg_t0 = time.perf_counter()
            # Keyword-based stage guess, only used to pick which existing slots
            # are passed to the extractor as context.
            initial_stage = detect_stage_from_keywords(
                text, state.current_stage or "childhood"
            )
            stage_slots_raw = state.slots.get(initial_stage, {}) or {}
            with agent_span(
                logger,
                "MemoirOrchestrator.ExtractionAgent.extract",
                segment_id=segment.id,
            ):
                result: ExtractionResult = self.extraction_agent.extract(
                    user_message=text,
                    current_stage=state.current_stage or "childhood",
                    stage_slots=stage_slots_raw,
                    llm=classify_extract_llm,
                )
            detected_stage = result.detected_stage
            extracted_slots = result.slots or {}
            for slot_name, snippet in extracted_slots.items():
                state = update_slot(detected_stage, slot_name, snippet, [segment.id])

            with agent_span(
                logger,
                "MemoirOrchestrator.ClassificationAgent.classify",
                segment_id=segment.id,
            ):
                classify_result = self.classification_agent.classify(
                    text=text,
                    fallback_stage=detected_stage,
                    llm=classify_extract_llm,
                    segment_id=segment.id,
                )
            chapter_category = classify_result.category

            # Nothing extracted and the LLM said "none": no story for this segment.
            if (not extracted_slots) and classify_result.llm_said_none:
                segment_skip_story_ids.add(str(segment.id))
            segment_chapter_category[str(segment.id)] = chapter_category

            if agent_summary_enabled():
                logger.info(
                    "MemoirOrchestrator.segment segment_id={} text_len={} "
                    "detected_stage={} category={} segment_total_ms={:.2f}",
                    segment.id,
                    len(text),
                    detected_stage,
                    chapter_category,
                    (time.perf_counter() - seg_t0) * 1000,
                )
            log_agent_detail(
                logger,
                "MemoirOrchestrator.segment_done segment_id={} slots={}",
                segment.id,
                list(extracted_slots.keys()),
            )
            category_to_segments.setdefault(chapter_category, []).append(segment)

        return PreparedMemoirBatches(
            state=state,
            category_to_segments=category_to_segments,
            segment_skip_story_ids=segment_skip_story_ids,
            segment_chapter_category=segment_chapter_category,
        )

    @staticmethod
    def _resolve_batch_category(
        text: str, chapter_category_raw: Any
    ) -> Tuple[str, bool]:
        """Map a batch row's raw category to ``(chapter_category, llm_said_none)``."""
        if _looks_like_fragment_only(text):
            # Fragment-only input is forced into "summary" regardless of the LLM.
            return "summary", False
        raw_cat = (chapter_category_raw or "").strip().lower()
        if raw_cat == "none":
            return "summary", True
        return normalize_chapter_category(chapter_category_raw, "summary"), False

    def _prepare_batches_via_batch_llm(
        self,
        *,
        segments: List[Segment],
        state: MemoirStateSchema,
        classify_extract_llm: Any,
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
        on_phase1_chunk: Optional[Callable[[int, int], None]] = None,
    ) -> PreparedMemoirBatches:
        """
        Batch path: one chunked LLM pass for all segments, then apply results.

        Raises:
            KeyError: if the batch result lacks an entry for any segment. This is
                checked before any ``update_slot`` call so the caller's fallback
                does not re-apply slots that were already written.
        """
        category_to_segments: Dict[str, List[Segment]] = {}
        segment_skip_story_ids: Set[str] = set()
        segment_chapter_category: Dict[str, str] = {}

        by_id = run_batch_phase1_prep_chunked(
            segments,
            state,
            classify_extract_llm,
            chunk_size=int(settings.memoir_phase1_batch_llm_chunk_size),
            on_chunk=on_phase1_chunk,
        )

        # Validate coverage up front: failing mid-loop would leave some slot
        # updates applied before the per-segment fallback re-applies them.
        missing_ids = [str(seg.id) for seg in segments if str(seg.id) not in by_id]
        if missing_ids:
            raise KeyError(f"batch phase1 result missing segment ids: {missing_ids}")

        for segment in segments:
            text = segment.user_input_text or ""
            seg_t0 = time.perf_counter()
            row = by_id[str(segment.id)]
            result_slots = dict(row.slots)
            # Recomputed per segment because update_slot may advance the state.
            fb = state.current_stage or "childhood"
            if not result_slots:
                detected_stage = normalize_chat_stage(fb, fb)
            else:
                detected_stage = normalize_chat_stage(row.detected_stage, fb)
                # Drop slots that are not allowed for the detected stage.
                allowed = STAGE_ALLOWED_SLOTS.get(detected_stage, frozenset())
                result_slots = {
                    k: v for k, v in result_slots.items() if k in allowed
                }
                if not result_slots:
                    detected_stage = normalize_chat_stage(fb, fb)

            with agent_span(
                logger,
                "MemoirOrchestrator.BatchPhase1Prep.apply",
                segment_id=segment.id,
            ):
                for slot_name, snippet in result_slots.items():
                    state = update_slot(
                        detected_stage, slot_name, snippet, [segment.id]
                    )

            chapter_category, llm_said_none = self._resolve_batch_category(
                text, row.chapter_category_raw
            )

            # Nothing extracted and the LLM said "none": no story for this segment.
            if (not result_slots) and llm_said_none:
                segment_skip_story_ids.add(str(segment.id))
            segment_chapter_category[str(segment.id)] = chapter_category

            if agent_summary_enabled():
                logger.info(
                    "MemoirOrchestrator.segment(batch) segment_id={} text_len={} "
                    "detected_stage={} category={} segment_total_ms={:.2f}",
                    segment.id,
                    len(text),
                    detected_stage,
                    chapter_category,
                    (time.perf_counter() - seg_t0) * 1000,
                )
            log_agent_detail(
                logger,
                "MemoirOrchestrator.segment_done(batch) segment_id={} slots={}",
                segment.id,
                list(result_slots.keys()),
            )
            category_to_segments.setdefault(chapter_category, []).append(segment)

        return PreparedMemoirBatches(
            state=state,
            category_to_segments=category_to_segments,
            segment_skip_story_ids=segment_skip_story_ids,
            segment_chapter_category=segment_chapter_category,
        )

    def run(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        user_profile: str = "",
        user_birth_year: Any = None,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
        acquire_lock: Callable[[str], bool],
        release_lock: Callable[[str], None],
        process_category: Callable[
            [
                str,
                List[Segment],
                MemoirStateSchema,
                str,
                Any,
                Any,
            ],
            Tuple[Any, bool],
        ],
        raise_retry: Callable[[], None],
        llm_fast: Any | None = None,
    ) -> Tuple[Set[str], int]:
        """
        Run the memoir pipeline.

        ``process_category(category, segments, state, user_profile,
        user_birth_year, llm)`` returns ``(chapter, has_images_to_generate)``.

        Returns:
            ``(chapters_to_enqueue, processed_count)``. ``chapters_to_enqueue``
            holds ids of chapters that have images to generate;
            ``processed_count`` is ``len(segments)`` (skipped segments included).

        Raises:
            Whatever ``raise_retry`` raises on lock contention (it is expected to
            raise a Celery retry).
            RuntimeError: if ``raise_retry`` returns instead of raising. This
                prevents running ``process_category`` without the lock.
        """
        prepared = self.prepare_batches(
            segments=segments,
            llm=llm,
            llm_fast=llm_fast,
            get_or_create_state=get_or_create_state,
            update_slot=update_slot,
            on_phase1_chunk=None,
        )
        state = prepared.state
        chapters_to_enqueue: Set[str] = set()

        # Per category: narrative generation, persistence, cover-enqueue flag.
        for chapter_category, category_segments in prepared.category_to_segments.items():
            if not acquire_lock(chapter_category):
                logger.warning(
                    "章节锁竞争: category={}, 延迟重试",
                    chapter_category,
                )
                raise_retry()
                # raise_retry must not return: continuing would run
                # process_category without the lock and then release a lock
                # this worker never acquired.
                raise RuntimeError(
                    f"raise_retry() returned without raising for "
                    f"category={chapter_category}"
                )
            try:
                chapter, has_images = process_category(
                    chapter_category,
                    category_segments,
                    state,
                    user_profile,
                    user_birth_year,
                    llm,
                )
                if chapter and has_images:
                    chapters_to_enqueue.add(chapter.id)
            finally:
                release_lock(chapter_category)

        return chapters_to_enqueue, len(segments)