Files
life-echo/api/app/agents/memoir/batch_phase1_prep.py

182 lines
6.1 KiB
Python
Raw Normal View History

"""
Phase1 批处理一次 LLM 调用完成多段的抽取 + 章节分类与逐段循环语义对齐
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List
from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt
from app.agents.memoir.schemas import BatchPhase1LLMOutput
from app.agents.stage_constants import STAGE_SLOT_KEYS
from app.agents.state_schema import MemoirStateSchema
from app.core.config import settings
from app.core.llm_call import LLMCallError, llm_json_call
from app.core.logging import get_logger
from app.features.conversation.models import Segment
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Same keys as STAGE_SLOT_KEYS, with each value frozen into an immutable set.
# NOTE(review): no function visible in this part of the file reads this
# constant; presumably it is imported elsewhere — confirm before removing.
STAGE_ALLOWED_SLOTS: Dict[str, frozenset[str]] = {
    k: frozenset(v) for k, v in STAGE_SLOT_KEYS.items()
}
def _slots_snapshot(state: MemoirStateSchema) -> dict:
    """Build a compact ``{stage: {slot_key: snippet}}`` view of ``state.slots``.

    A slot value may be an object exposing a ``snippet`` attribute or a plain
    dict with a ``"snippet"`` key. Whatever the source, only a string snippet
    is kept; a missing, ``None`` or non-string snippet is represented by ``""``.
    Every snippet is truncated to at most 120 characters.

    Args:
        state: Object whose ``slots`` attribute maps stage -> {slot_key: value}.
            ``slots`` itself and any per-stage mapping may be ``None``/empty.

    Returns:
        A new nested dict of plain strings; stages with no slots map to ``{}``.
    """
    snapshot: dict = {}
    for stage, buckets in (state.slots or {}).items():
        stage_view = snapshot[stage] = {}
        for key, value in (buckets or {}).items():
            # Attribute access is checked first, then dict access, mirroring
            # the precedence callers have relied on so far.
            if hasattr(value, "snippet"):
                raw = getattr(value, "snippet", None)
            elif isinstance(value, dict):
                raw = value.get("snippet")
            else:
                raw = None
            # The same type guard applies to both sources: slicing a
            # non-string (e.g. an int) would raise, and slicing a list would
            # leak a non-string value into the snapshot.
            stage_view[key] = raw[:120] if isinstance(raw, str) else ""
    return snapshot
@dataclass(frozen=True)
class BatchPhase1SegmentRow:
    """Immutable per-segment result row of the Phase 1 batch preparation.

    ``run_batch_phase1_prep`` builds one instance per segment id from the
    parsed LLM output.
    """

    # Stage reported for the segment, stripped and lower-cased by the builder;
    # the builder substitutes the state's current stage (or "childhood") when
    # the LLM left it empty.
    detected_stage: str
    # Slot name -> value; the builder coerces non-string values with str().
    slots: Dict[str, str]
    # Chapter category exactly as returned by the LLM, stringified only.
    chapter_category_raw: str
def run_batch_phase1_prep(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run one batched Phase 1 LLM call and return its rows keyed by segment id.

    Args:
        segments: Segments handled together in a single prompt. Their ``id``
            values (as strings) are the ids the response must contain.
        state: Memoir state; supplies ``current_stage`` and the slots that are
            summarised into the prompt.
        llm: LLM handle forwarded to ``llm_json_call``; must be truthy.

    Returns:
        Mapping of segment id to ``BatchPhase1SegmentRow``; ``{}`` when
        ``segments`` is empty.

    Raises:
        ValueError: If ``llm`` is falsy, the LLM call raises ``LLMCallError``,
            the parsed response has no rows, or the set of ids in the response
            is not exactly the set of input ids.
    """
    if not llm:
        raise ValueError("batch phase1 requires llm")
    if not segments:
        return {}

    # (id, stripped user text) pairs handed to the prompt builder.
    segment_items = []
    for seg in segments:
        seg_id = str(seg.id)
        segment_items.append((seg_id, (seg.user_input_text or "").strip()))

    prompt = get_batch_memoir_phase1_prep_prompt(
        system_current_stage=state.current_stage or "childhood",
        slots_snapshot=_slots_snapshot(state),
        segment_items=segment_items,
    )

    try:
        parsed = llm_json_call(
            llm,
            prompt,
            BatchPhase1LLMOutput,
            max_tokens=int(settings.memoir_phase1_batch_llm_max_tokens),
            agent="BatchPhase1Prep.run",
        )
    except LLMCallError as exc:
        logger.warning("batch phase1 LLM 解析失败: {}", exc)
        raise ValueError("batch phase1: llm parse failed") from exc

    llm_rows = parsed.segments
    if not llm_rows:
        raise ValueError("batch phase1: segments must be a non-empty list")

    result: Dict[str, BatchPhase1SegmentRow] = {}
    for llm_row in llm_rows:
        row_id = str(llm_row.id).strip()
        if row_id:
            stage = str(llm_row.detected_stage or "").strip().lower()
            # Keep only non-empty string keys; stringify non-string values.
            # NOTE(review): a null value would become the literal "None"
            # here — confirm the response schema rules that out.
            slot_values = {}
            for name, value in (llm_row.slots or {}).items():
                if name and isinstance(name, str):
                    slot_values[name] = (
                        value if isinstance(value, str) else str(value)
                    )
            category = str(llm_row.chapter_category or "")
            # A repeated id overwrites the earlier row (last one wins).
            result[row_id] = BatchPhase1SegmentRow(
                detected_stage=stage or state.current_stage or "childhood",
                slots=slot_values,
                chapter_category_raw=category,
            )

    # The response must cover exactly the requested ids: nothing missing,
    # nothing extra.
    expected_ids = {str(seg.id) for seg in segments}
    if result.keys() == expected_ids:
        return result

    missing = expected_ids - result.keys()
    extra = result.keys() - expected_ids
    logger.warning("batch phase1 id mismatch missing={} extra={}", missing, extra)
    raise ValueError("batch phase1 response segment ids do not match input")
def _run_batch_phase1_prep_chunk_with_bisect(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Process one chunk, halving it and retrying whenever the batch call fails.

    The whole chunk is first sent through ``run_batch_phase1_prep``. If that
    raises ``ValueError`` (for example because the LLM output was truncated),
    the chunk is split in two and each half is processed recursively, down to
    single segments.

    Args:
        segments: The chunk to process.
        state: Memoir state forwarded unchanged to ``run_batch_phase1_prep``.
        llm: LLM handle forwarded unchanged to ``run_batch_phase1_prep``.

    Returns:
        Mapping of segment id to row covering exactly the ids in ``segments``.

    Raises:
        ValueError: If a chunk of at most one segment still fails, or if the
            merged halves do not cover exactly the input ids.
    """
    try:
        return run_batch_phase1_prep(segments, state, llm)
    except ValueError:
        # Nothing left to split: surface the original failure to the caller.
        if len(segments) <= 1:
            raise
        # len(segments) >= 2 here, so mid >= 1 and both halves are non-empty
        # and strictly smaller than the input; the recursion terminates.
        mid = len(segments) // 2
        left = _run_batch_phase1_prep_chunk_with_bisect(segments[:mid], state, llm)
        right = _run_batch_phase1_prep_chunk_with_bisect(segments[mid:], state, llm)
        merged = {**left, **right}
        expected = {str(s.id) for s in segments}
        if merged.keys() != expected:
            # ``from None`` drops the chained context of the earlier batch
            # failure, which is not the cause of this mismatch.
            raise ValueError(
                "batch phase1 chunked bisect merge: segment ids do not match input"
            ) from None
        return merged
def run_batch_phase1_prep_chunked(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    chunk_size: int,
    on_chunk: Callable[[int, int], None] | None = None,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run the Phase 1 batch LLM over ``segments`` in slices of ``chunk_size``.

    Each slice goes through ``_run_batch_phase1_prep_chunk_with_bisect``, so a
    slice that fails as a whole is bisected down to single segments before the
    error is allowed to propagate. The per-slice results are merged by id.
    NOTE(review): the original docstring says the final failure is meant to
    hand over to a per-segment fallback in the outer orchestrator — that
    caller is not visible here.

    Args:
        segments: All segments to process, in order.
        state: Memoir state forwarded to every slice.
        llm: LLM handle forwarded to every slice.
        chunk_size: Maximum number of segments per slice; values below 1 are
            treated as 1.
        on_chunk: Optional callback invoked after each slice completes with
            ``(chunk_index, total_chunks)``; the index starts at 1.

    Returns:
        Mapping of segment id to row for every input segment; ``{}`` when
        ``segments`` is empty.

    Raises:
        ValueError: If a slice ultimately fails, or if the merged ids do not
            match the input ids.
    """
    if not segments:
        return {}

    step = max(chunk_size, 1)
    segment_total = len(segments)
    total_chunks = max(1, math.ceil(segment_total / step))

    combined: Dict[str, BatchPhase1SegmentRow] = {}
    starts = range(0, segment_total, step)
    for chunk_idx, start in enumerate(starts, start=1):
        chunk = segments[start : start + step]
        logger.info(
            "event=batch_phase1_chunk chunk_idx={}/{} segment_count={} batch_path=chunked "
            "msg=Phase1 批处理分块调用",
            chunk_idx,
            total_chunks,
            len(chunk),
        )
        combined.update(_run_batch_phase1_prep_chunk_with_bisect(chunk, state, llm))
        if on_chunk is not None:
            on_chunk(chunk_idx, total_chunks)

    # Final cross-check over the full input, independent of per-slice checks.
    expected_ids = {str(seg.id) for seg in segments}
    if combined.keys() == expected_ids:
        return combined

    missing = expected_ids - combined.keys()
    extra = combined.keys() - expected_ids
    logger.warning(
        "batch phase1 chunked id mismatch missing={} extra={}",
        missing,
        extra,
    )
    raise ValueError("batch phase1 chunked: merged segment ids do not match input")