"""
MemoirOrchestrator: drives the memoir pipeline segment by segment.

Responsibilities: iterate over segments, group them by category, call the
Specialist Agents, and update state. Persistence and chapter generation are
performed by the ``process_category`` callback, not by this module.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set

from app.agents.memoir.batch_phase1_prep import run_batch_phase1_prep_chunked
from app.agents.memoir.classification_agent import (
    ClassificationAgent,
    _looks_like_fragment_only,
)
from app.agents.memoir.classification_agent import (
    _detect_stage as detect_stage_from_keywords,
)
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
from app.agents.stage_constants import (
    filter_stage_slots,
    normalize_chapter_category,
    normalize_chat_stage,
)
from app.agents.state_schema import MemoirStateSchema
from app.core.agent_logging import agent_span, agent_summary_enabled, log_agent_detail
from app.core.config import settings
from app.core.logging import get_logger
from app.features.conversation.models import Segment

logger = get_logger(__name__)


@dataclass
class PreparedMemoirBatches:
    """Explicit batching result: updated state + segments grouped by chapter category."""

    state: MemoirStateSchema
    category_to_segments: Dict[str, List[Segment]]
    #: A segment id is added when the LLM answered "none" AND extraction
    #: produced no slots; the batch-level short-circuit lives in memoir_tasks.
    segment_skip_story_ids: Set[str]
    #: Maps each segment id to its Phase 1 chapter_category (persisted to
    #: Segment.topic_category).
    segment_chapter_category: Dict[str, str]


class MemoirOrchestrator:
    """
    Orchestrator for memoir generation.

    Walks the segments → ExtractionAgent → ClassificationAgent → groups them by
    category. Narrative generation and persistence happen afterwards via the
    caller's ``process_category`` step.

    ``extraction_agent`` / ``classification_agent`` can be injected so tests
    can supply stand-ins.
    """

    def __init__(
        self,
        *,
        extraction_agent: ExtractionAgent | None = None,
        classification_agent: ClassificationAgent | None = None,
    ) -> None:
        """Store the injected agents, creating default ones when omitted."""
        self.extraction_agent = extraction_agent or ExtractionAgent()
        self.classification_agent = classification_agent or ClassificationAgent()

    def prepare_batches(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
        llm_fast: Any | None = None,
        on_phase1_chunk: Optional[Callable[[int, int], None]] = None,
        language: str = "zh",
    ) -> PreparedMemoirBatches:
        """
        Run Phase 1 over ``segments`` and bucket them by chapter category.

        Per segment: Extraction → slot update → Classification → bucket by
        category. This method takes no locks and writes no chapters/stories;
        the caller does that explicitly.

        Args:
            segments: Segments to process, in order.
            llm: Main LLM; used here only as the fallback for ``llm_fast``.
            get_or_create_state: Returns the current memoir state.
            update_slot: Called as ``update_slot(stage, slot_name, snippet,
                [segment.id])``; its return value replaces the working state.
            llm_fast: LLM dedicated to classification and extraction; when
                ``None`` the ``llm`` argument is used instead.
            on_phase1_chunk: Optional progress callback forwarded to the batch
                path.
            language: Language code forwarded to the agents.

        Returns:
            The updated state, the category → segments mapping, the ids of
            segments whose story should be skipped, and each segment's
            chapter category.
        """
        state = get_or_create_state()
        category_to_segments: Dict[str, List[Segment]] = {}
        segment_skip_story_ids: Set[str] = set()
        segment_chapter_category: Dict[str, str] = {}
        classify_extract_llm = llm_fast if llm_fast is not None else llm

        # The batch path is the default main path (needs an LLM + the feature
        # flag); on failure we automatically fall back to per-segment handling.
        use_batch = (
            bool(segments)
            and classify_extract_llm is not None
            and settings.memoir_phase1_batch_llm_enabled
        )
        if use_batch:
            try:
                prepared_batch = self._prepare_batches_via_batch_llm(
                    segments=segments,
                    state=state,
                    classify_extract_llm=classify_extract_llm,
                    update_slot=update_slot,
                    on_phase1_chunk=on_phase1_chunk,
                    language=language,
                )
                logger.info(
                    "event=phase1_batch_path_used segment_count={} "
                    "msg=Phase1 批处理 LLM 路径已使用",
                    len(segments),
                )
                return prepared_batch
            except Exception as e:
                # Deliberate best-effort: any batch failure degrades to the
                # per-segment path below instead of aborting the run.
                logger.warning(
                    "event=phase1_batch_path_fallback segment_count={} exc={} "
                    "msg=Phase1 批处理失败,回退逐段",
                    len(segments),
                    e,
                )

        for segment in segments:
            text = segment.user_input_text or ""
            seg_t0 = time.perf_counter()
            initial_stage = detect_stage_from_keywords(
                text, state.current_stage or "childhood"
            )
            stage_slots_raw = state.slots.get(initial_stage, {}) or {}
            with agent_span(
                logger,
                "MemoirOrchestrator.ExtractionAgent.extract",
                segment_id=segment.id,
            ):
                result: ExtractionResult = self.extraction_agent.extract(
                    user_message=text,
                    current_stage=state.current_stage or "childhood",
                    stage_slots=stage_slots_raw,
                    llm=classify_extract_llm,
                    language=language,
                )

            fb = state.current_stage or "childhood"
            detected_stage = normalize_chat_stage(result.detected_stage, fb)
            result_slots = filter_stage_slots(detected_stage, result.slots, fb)
            if not result_slots:
                # Nothing usable was extracted: fall back to the current stage.
                detected_stage = normalize_chat_stage(fb, fb)
            for slot_name, snippet in result_slots.items():
                state = update_slot(detected_stage, slot_name, snippet, [segment.id])

            with agent_span(
                logger,
                "MemoirOrchestrator.ClassificationAgent.classify",
                segment_id=segment.id,
            ):
                classify_result = self.classification_agent.classify(
                    text=text,
                    fallback_stage=detected_stage,
                    llm=classify_extract_llm,
                    segment_id=segment.id,
                    language=language,
                )
            chapter_category = classify_result.category
            # Skip the story only when BOTH signals agree there is no content.
            if (not result_slots) and classify_result.llm_said_none:
                segment_skip_story_ids.add(str(segment.id))
            segment_chapter_category[str(segment.id)] = chapter_category

            if agent_summary_enabled():
                logger.info(
                    "MemoirOrchestrator.segment segment_id={} text_len={} "
                    "detected_stage={} category={} segment_total_ms={:.2f}",
                    segment.id,
                    len(text),
                    detected_stage,
                    chapter_category,
                    (time.perf_counter() - seg_t0) * 1000,
                )
            log_agent_detail(
                logger,
                "MemoirOrchestrator.segment_done segment_id={} slots={}",
                segment.id,
                list(result_slots.keys()),
            )
            category_to_segments.setdefault(chapter_category, []).append(segment)

        return PreparedMemoirBatches(
            state=state,
            category_to_segments=category_to_segments,
            segment_skip_story_ids=segment_skip_story_ids,
            segment_chapter_category=segment_chapter_category,
        )

    def _prepare_batches_via_batch_llm(
        self,
        *,
        segments: List[Segment],
        state: MemoirStateSchema,
        classify_extract_llm: Any,
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
        on_phase1_chunk: Optional[Callable[[int, int], None]] = None,
        language: str = "zh",
    ) -> PreparedMemoirBatches:
        """
        Phase 1 via the chunked batch LLM call, then apply results per segment.

        Args:
            segments: Segments to process, in order.
            state: Starting memoir state.
            classify_extract_llm: LLM handed to the batch prep call.
            update_slot: Called as ``update_slot(stage, slot_name, snippet,
                [segment.id])``; its return value replaces the working state.
            on_phase1_chunk: Optional progress callback passed as ``on_chunk``.
            language: Language code forwarded to the batch prep call.

        Returns:
            Same shape as ``prepare_batches``.

        Raises:
            Whatever ``run_batch_phase1_prep_chunked`` raises, plus the lookup
            error when the batch result has no row for a segment id. Row
            problems are raised BEFORE any ``update_slot`` call.
        """
        category_to_segments: Dict[str, List[Segment]] = {}
        segment_skip_story_ids: Set[str] = set()
        segment_chapter_category: Dict[str, str] = {}

        by_id = run_batch_phase1_prep_chunked(
            segments,
            state,
            classify_extract_llm,
            chunk_size=int(settings.memoir_phase1_batch_llm_chunk_size),
            on_chunk=on_phase1_chunk,
            language=language,
        )

        # Resolve every row (and copy its slots) before touching state.
        # ``prepare_batches`` reacts to any exception from this method by
        # re-running ALL segments on the per-segment path; if a missing or
        # malformed row were only discovered mid-loop, the slot updates already
        # applied for earlier segments would be applied a second time.
        # NOTE(review): a failure inside ``update_slot`` itself can still leave
        # a partial application; that needs transactional support in the caller.
        resolved = []
        for segment in segments:
            row = by_id[str(segment.id)]
            resolved.append((segment, row, dict(row.slots)))

        for segment, row, result_slots in resolved:
            text = segment.user_input_text or ""
            seg_t0 = time.perf_counter()

            fb = state.current_stage or "childhood"
            if not result_slots:
                detected_stage = normalize_chat_stage(fb, fb)
            else:
                detected_stage = normalize_chat_stage(row.detected_stage, fb)
                result_slots = filter_stage_slots(detected_stage, result_slots, fb)
                if not result_slots:
                    # Filtering removed everything: fall back to current stage.
                    detected_stage = normalize_chat_stage(fb, fb)

            with agent_span(
                logger,
                "MemoirOrchestrator.BatchPhase1Prep.apply",
                segment_id=segment.id,
            ):
                for slot_name, snippet in result_slots.items():
                    state = update_slot(
                        detected_stage, slot_name, snippet, [segment.id]
                    )

            if _looks_like_fragment_only(text):
                # Fragment-only text is filed under "summary" and is never
                # treated as an LLM "none" verdict.
                chapter_category = "summary"
                llm_said_none = False
            else:
                raw_cat = (row.chapter_category_raw or "").strip().lower()
                if raw_cat == "none":
                    chapter_category = "summary"
                    llm_said_none = True
                else:
                    chapter_category = normalize_chapter_category(
                        row.chapter_category_raw,
                        "summary",
                    )
                    llm_said_none = False

            # Skip the story only when BOTH signals agree there is no content.
            if (not result_slots) and llm_said_none:
                segment_skip_story_ids.add(str(segment.id))
            segment_chapter_category[str(segment.id)] = chapter_category

            if agent_summary_enabled():
                logger.info(
                    "MemoirOrchestrator.segment(batch) segment_id={} text_len={} "
                    "detected_stage={} category={} segment_total_ms={:.2f}",
                    segment.id,
                    len(text),
                    detected_stage,
                    chapter_category,
                    (time.perf_counter() - seg_t0) * 1000,
                )
            log_agent_detail(
                logger,
                "MemoirOrchestrator.segment_done(batch) segment_id={} slots={}",
                segment.id,
                list(result_slots.keys()),
            )
            category_to_segments.setdefault(chapter_category, []).append(segment)

        return PreparedMemoirBatches(
            state=state,
            category_to_segments=category_to_segments,
            segment_skip_story_ids=segment_skip_story_ids,
            segment_chapter_category=segment_chapter_category,
        )