""" Phase1 批处理:一次 LLM 调用完成多段的抽取 + 章节分类(与逐段循环语义对齐)。 """ from __future__ import annotations import math from dataclasses import dataclass from typing import Any, Callable, Dict, List from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt from app.agents.memoir.schemas import BatchPhase1LLMOutput from app.agents.state_schema import MemoirStateSchema from app.core.config import settings from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger from app.features.conversation.models import Segment from app.features.memoir.constants import memoir logger = get_logger(__name__) def _slots_snapshot(state: MemoirStateSchema) -> dict: snap: dict = {} for stage, buckets in (state.slots or {}).items(): snap[stage] = {} for k, v in (buckets or {}).items(): if hasattr(v, "snippet"): sn = getattr(v, "snippet", None) or "" elif isinstance(v, dict): sn = ( (v.get("snippet") or "") if isinstance(v.get("snippet"), str) else "" ) else: sn = "" snap[stage][k] = (sn or "")[:120] return snap @dataclass(frozen=True) class BatchPhase1SegmentRow: detected_stage: str slots: Dict[str, str] chapter_category_raw: str def run_batch_phase1_prep( segments: List[Segment], state: MemoirStateSchema, llm: Any, *, language: str = "zh", ) -> Dict[str, BatchPhase1SegmentRow]: """对 segments 顺序批量调用 LLM;返回 id → 行。id 集合必须与入参完全一致。""" if not llm: raise ValueError("batch phase1 requires llm") if not segments: return {} items = [(str(s.id), (s.user_input_text or "").strip()) for s in segments] prompt = get_batch_memoir_phase1_prep_prompt( system_current_stage=state.current_stage or "childhood", slots_snapshot=_slots_snapshot(state), segment_items=items, language=language, ) try: parsed = llm_json_call( llm, prompt, BatchPhase1LLMOutput, max_tokens=int(memoir.phase1_batch_llm_max_tokens), agent="BatchPhase1Prep.run", ) except LLMCallError as e: logger.warning("batch phase1 LLM 解析失败: {}", e) raise ValueError("batch phase1: llm parse failed") from e rows = parsed.segments if not rows: raise ValueError("batch phase1: segments must be a non-empty list") by_id: Dict[str, BatchPhase1SegmentRow] = {} for row in rows: sid = str(row.id).strip() if not sid: continue ds = str(row.detected_stage or "").strip().lower() slots_raw = row.slots or {} slots = { k: v if isinstance(v, str) else str(v) for k, v in slots_raw.items() if k and isinstance(k, str) } cat_raw = str(row.chapter_category or "") by_id[sid] = BatchPhase1SegmentRow( detected_stage=ds or (state.current_stage or "childhood"), slots=slots, chapter_category_raw=cat_raw, ) expected = {str(s.id) for s in segments} if by_id.keys() != expected: missing = expected - by_id.keys() extra = by_id.keys() - expected logger.warning("batch phase1 id mismatch missing={} extra={}", missing, extra) raise ValueError("batch phase1 response segment ids do not match input") return by_id def _run_batch_phase1_prep_chunk_with_bisect( segments: List[Segment], state: MemoirStateSchema, llm: Any, *, language: str = "zh", ) -> Dict[str, BatchPhase1SegmentRow]: """单块 LLM;失败时(如输出截断)将块二等分重试直至单段。""" try: return run_batch_phase1_prep(segments, state, llm, language=language) except ValueError: if len(segments) <= 1: raise mid = len(segments) // 2 if mid < 1: raise left = _run_batch_phase1_prep_chunk_with_bisect( segments[:mid], state, llm, language=language ) right = _run_batch_phase1_prep_chunk_with_bisect( segments[mid:], state, llm, language=language ) merged = {**left, **right} expected = {str(s.id) for s in segments} if merged.keys() != expected: raise ValueError( "batch phase1 chunked bisect merge: segment ids do not match input" ) from None return merged def run_batch_phase1_prep_chunked( segments: List[Segment], state: MemoirStateSchema, llm: Any, *, chunk_size: int, on_chunk: Callable[[int, int], None] | None = None, language: str = "zh", ) -> Dict[str, BatchPhase1SegmentRow]: """ 将 segments 按 chunk_size 切片多次调用 Phase1 批处理 LLM,合并 by_id。 单块仍失败时在块内二分回退(最后回退到单段),与 orchestrator 外层逐段回退衔接。 """ if not segments: return {} if chunk_size < 1: chunk_size = 1 n = len(segments) total_chunks = max(1, math.ceil(n / chunk_size)) merged: Dict[str, BatchPhase1SegmentRow] = {} for i in range(0, n, chunk_size): chunk_idx = i // chunk_size + 1 sub = segments[i : i + chunk_size] logger.info( "event=batch_phase1_chunk chunk_idx={}/{} segment_count={} batch_path=chunked " "msg=Phase1 批处理分块调用", chunk_idx, total_chunks, len(sub), ) part = _run_batch_phase1_prep_chunk_with_bisect( sub, state, llm, language=language ) merged.update(part) if on_chunk is not None: on_chunk(chunk_idx, total_chunks) expected = {str(s.id) for s in segments} if merged.keys() != expected: missing = expected - merged.keys() extra = merged.keys() - expected logger.warning( "batch phase1 chunked id mismatch missing={} extra={}", missing, extra, ) raise ValueError("batch phase1 chunked: merged segment ids do not match input") return merged