110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
"""
|
||
Phase1 批处理:一次 LLM 调用完成多段的抽取 + 章节分类(与逐段循环语义对齐)。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, List
|
||
|
||
from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt
|
||
from app.agents.memoir.schemas import BatchPhase1LLMOutput
|
||
from app.agents.state_schema import MemoirStateSchema
|
||
from app.agents.stage_constants import STAGE_SLOT_KEYS
|
||
from app.core.config import settings
|
||
from app.core.llm_call import LLMCallError, llm_json_call
|
||
from app.core.logging import get_logger
|
||
from app.features.conversation.models import Segment
|
||
|
||
# Logger for this module, obtained from the project's logging helper.
logger = get_logger(__name__)

# Stage name -> immutable set of slot keys, built from STAGE_SLOT_KEYS.
# Each value is converted to a frozenset so membership tests are cheap and
# the per-stage sets cannot be modified after import.
STAGE_ALLOWED_SLOTS: Dict[str, frozenset[str]] = dict(
    (stage_name, frozenset(slot_keys))
    for stage_name, slot_keys in STAGE_SLOT_KEYS.items()
)
|
||
|
||
|
||
def _slots_snapshot(state: MemoirStateSchema) -> dict:
    """Build a compact ``{stage: {slot_key: snippet}}`` view of ``state.slots``.

    For every slot value the snippet is taken from its ``snippet`` attribute
    (when the value has one) or from its ``"snippet"`` entry (when the value
    is a dict). Anything else, and any snippet that is not a string, is
    represented by ``""``. Snippets are truncated to 120 characters.

    Args:
        state: Object exposing ``slots``, a mapping of stage to a mapping of
            slot key to slot value. A falsy ``slots`` or a falsy per-stage
            mapping is treated as empty.

    Returns:
        A new nested dict; every leaf is a string of at most 120 characters.
    """
    snapshot: dict = {}
    for stage, buckets in (state.slots or {}).items():
        stage_view: dict = {}
        for key, value in (buckets or {}).items():
            if hasattr(value, "snippet"):
                raw = getattr(value, "snippet", None)
            elif isinstance(value, dict):
                raw = value.get("snippet")
            else:
                raw = None
            # Apply the same "must be a str" rule to both sources. Without
            # it, an object whose ``snippet`` is e.g. an int would raise on
            # slicing, and a list/bytes value would leak into the snapshot.
            stage_view[key] = raw[:120] if isinstance(raw, str) else ""
        snapshot[stage] = stage_view
    return snapshot
|
||
|
||
|
||
@dataclass(frozen=True)
class BatchPhase1SegmentRow:
    """Immutable per-segment result row produced by the batch phase-1 step."""

    # Stage name reported for the segment.
    detected_stage: str
    # Slot key -> slot text for the segment.
    slots: Dict[str, str]
    # Chapter category string, kept as received (hence "raw").
    chapter_category_raw: str
|
||
|
||
|
||
def run_batch_phase1_prep(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run one batched LLM call over ``segments`` and return rows keyed by id.

    The segments are sent, in order, in a single prompt. The parsed response
    is normalised into ``BatchPhase1SegmentRow`` objects, and the set of ids
    in the response must match the ids of the input segments exactly.

    Args:
        segments: Segments to process; each contributes ``str(id)`` and its
            stripped ``user_input_text`` to the prompt.
        state: Memoir state; supplies ``current_stage`` (falling back to
            ``"childhood"``) and the slots snapshot for the prompt.
        llm: LLM handle passed through to ``llm_json_call``.

    Returns:
        Mapping of segment id (as a string) to its row. Empty when
        ``segments`` is empty.

    Raises:
        ValueError: If ``llm`` is falsy, the LLM call raises
            ``LLMCallError``, the parsed response has no segments, or the
            response ids differ from the input ids.
    """
    # The llm check runs first, so a missing llm is reported even when the
    # batch is empty.
    if not llm:
        raise ValueError("batch phase1 requires llm")
    if not segments:
        return {}

    segment_items = [
        (str(seg.id), (seg.user_input_text or "").strip()) for seg in segments
    ]
    prompt = get_batch_memoir_phase1_prep_prompt(
        system_current_stage=state.current_stage or "childhood",
        slots_snapshot=_slots_snapshot(state),
        segment_items=segment_items,
    )

    try:
        token_budget = int(settings.memoir_phase1_batch_llm_max_tokens)
        parsed = llm_json_call(
            llm,
            prompt,
            BatchPhase1LLMOutput,
            max_tokens=token_budget,
            agent="BatchPhase1Prep.run",
        )
    except LLMCallError as exc:
        logger.warning("batch phase1 LLM 解析失败: {}", exc)
        raise ValueError("batch phase1: llm parse failed") from exc

    llm_rows = parsed.segments
    if not llm_rows:
        raise ValueError("batch phase1: segments must be a non-empty list")

    by_id: Dict[str, BatchPhase1SegmentRow] = {}
    for llm_row in llm_rows:
        segment_id = str(llm_row.id).strip()
        if not segment_id:
            # Rows without a usable id are dropped; if that leaves an input
            # segment uncovered, the id check below rejects the batch.
            continue

        stage = str(llm_row.detected_stage or "").strip().lower()

        # Keep only non-empty string keys; non-string values are stringified.
        # NOTE(review): this turns a None value into the text "None" —
        # confirm that is intended.
        slot_values: Dict[str, str] = {}
        for key, value in (llm_row.slots or {}).items():
            if not key or not isinstance(key, str):
                continue
            slot_values[key] = value if isinstance(value, str) else str(value)

        category = str(llm_row.chapter_category or "")

        if not stage:
            # No stage reported: fall back to the state's current stage.
            stage = state.current_stage or "childhood"

        # A later row with the same id replaces an earlier one.
        by_id[segment_id] = BatchPhase1SegmentRow(
            detected_stage=stage,
            slots=slot_values,
            chapter_category_raw=category,
        )

    expected = {str(seg.id) for seg in segments}
    if by_id.keys() == expected:
        return by_id

    missing = expected - by_id.keys()
    extra = by_id.keys() - expected
    logger.warning("batch phase1 id mismatch missing={} extra={}", missing, extra)
    raise ValueError("batch phase1 response segment ids do not match input")
|