Files
life-echo/api/app/agents/memoir/batch_phase1_prep.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

189 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Phase1 batch processing: a single LLM call performs slot extraction + chapter
classification for several segments at once (semantics aligned with the
per-segment loop).
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List
from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt
from app.agents.memoir.schemas import BatchPhase1LLMOutput
from app.agents.state_schema import MemoirStateSchema
from app.core.config import settings
from app.core.llm_call import LLMCallError, llm_json_call
from app.core.logging import get_logger
from app.features.conversation.models import Segment
from app.features.memoir.constants import memoir
# Module logger. Call sites in this file pass ``{}``-placeholder arguments
# (NOTE(review): presumably a loguru-style logger from app.core.logging — confirm).
logger = get_logger(__name__)
def _slot_snippet(value: Any) -> str:
    """Return ``value``'s string ``snippet`` (attribute or dict key), else ``""``."""
    if hasattr(value, "snippet"):
        raw = getattr(value, "snippet", None)
    elif isinstance(value, dict):
        raw = value.get("snippet")
    else:
        raw = None
    # Non-string snippets (None, numbers, lists, ...) count as absent; slicing
    # them below would raise or leak a non-string into the prompt snapshot.
    return raw if isinstance(raw, str) else ""


def _slots_snapshot(state: MemoirStateSchema, *, max_len: int = 120) -> dict:
    """Build a compact ``{stage: {slot_key: snippet}}`` view of ``state.slots``.

    Used as prompt context for the batch Phase1 call. Each slot value
    contributes its ``snippet``, read either as an attribute or as a dict key,
    truncated to ``max_len`` characters. Values without a string snippet map
    to ``""``. A missing/empty ``state.slots`` yields ``{}``, and a
    ``None``/empty bucket yields an empty inner mapping.

    Args:
        state: Memoir state whose ``slots`` attribute is read (may be ``None``).
        max_len: Maximum number of characters kept per snippet.

    Returns:
        Nested dict of stage -> slot key -> truncated snippet string.
    """
    snap: dict = {}
    for stage, buckets in (state.slots or {}).items():
        snap[stage] = {
            key: _slot_snippet(value)[:max_len]
            for key, value in (buckets or {}).items()
        }
    return snap
@dataclass(frozen=True)
class BatchPhase1SegmentRow:
    """Immutable per-segment result of a Phase1 batch LLM call.

    In this module, instances are built by ``run_batch_phase1_prep`` and
    returned keyed by segment id.
    """

    # Stage reported by the model (lower-cased); falls back to the state's
    # current stage when the model left it blank.
    detected_stage: str
    # Slot name -> slot value, every value already coerced to ``str``.
    slots: dict[str, str]
    # Chapter category exactly as the model returned it, only ``str()``-converted.
    chapter_category_raw: str
def run_batch_phase1_prep(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    language: str = "zh",
) -> Dict[str, BatchPhase1SegmentRow]:
    """Extract slots and classify chapters for ``segments`` in one LLM call.

    All segments are sent in a single prompt. The parsed response must contain
    exactly one row per input segment id.

    Args:
        segments: Conversation segments to process. Each contributes
            ``str(id)`` and its stripped ``user_input_text``.
        state: Memoir state supplying ``current_stage`` and ``slots`` as
            prompt context.
        llm: LLM client forwarded to ``llm_json_call``; must be truthy.
        language: Prompt language code forwarded to the prompt builder.

    Returns:
        Mapping of segment id (as ``str``) to its ``BatchPhase1SegmentRow``.
        Empty when ``segments`` is empty.

    Raises:
        ValueError: ``llm`` is falsy; the LLM output could not be parsed; the
            response has no rows; one segment id appears twice with differing
            content; or the returned id set differs from the input id set.
    """
    if not llm:
        raise ValueError("batch phase1 requires llm")
    if not segments:
        return {}
    fallback_stage = state.current_stage or "childhood"
    items = [(str(s.id), (s.user_input_text or "").strip()) for s in segments]
    prompt = get_batch_memoir_phase1_prep_prompt(
        system_current_stage=fallback_stage,
        slots_snapshot=_slots_snapshot(state),
        segment_items=items,
        language=language,
    )
    try:
        parsed = llm_json_call(
            llm,
            prompt,
            BatchPhase1LLMOutput,
            max_tokens=int(memoir.phase1_batch_llm_max_tokens),
            agent="BatchPhase1Prep.run",
        )
    except LLMCallError as e:
        logger.warning("batch phase1 LLM 解析失败: {}", e)
        raise ValueError("batch phase1: llm parse failed") from e
    rows = parsed.segments
    if not rows:
        raise ValueError("batch phase1: segments must be a non-empty list")
    by_id: Dict[str, BatchPhase1SegmentRow] = {}
    for row in rows:
        sid = str(row.id).strip()
        if not sid:
            # A row without an id cannot be aligned to any segment. The id-set
            # comparison below reports the resulting gap.
            continue
        ds = str(row.detected_stage or "").strip().lower()
        slots_raw = row.slots or {}
        # NOTE(review): a None slot value becomes the literal string "None"
        # here — confirm the BatchPhase1LLMOutput schema never yields None.
        slots = {
            k: v if isinstance(v, str) else str(v)
            for k, v in slots_raw.items()
            if k and isinstance(k, str)
        }
        candidate = BatchPhase1SegmentRow(
            detected_stage=ds or fallback_stage,
            slots=slots,
            chapter_category_raw=str(row.chapter_category or ""),
        )
        previous = by_id.get(sid)
        if previous is not None and previous != candidate:
            # Two different rows claim the same segment. Silently keeping one
            # would hide a misaligned response, so fail and let the caller
            # retry on a smaller batch.
            logger.warning("batch phase1 conflicting duplicate segment id={}", sid)
            raise ValueError(
                "batch phase1 response contains conflicting duplicate segment ids"
            )
        by_id[sid] = candidate
    expected = {str(s.id) for s in segments}
    if by_id.keys() != expected:
        missing = expected - by_id.keys()
        extra = by_id.keys() - expected
        logger.warning("batch phase1 id mismatch missing={} extra={}", missing, extra)
        raise ValueError("batch phase1 response segment ids do not match input")
    return by_id
def _run_batch_phase1_prep_chunk_with_bisect(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    language: str = "zh",
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run Phase1 batch prep on one chunk, bisecting on failure.

    When the batched call raises ``ValueError`` (e.g. truncated output or
    mismatched ids), the chunk is split into two halves. Each half is retried
    recursively, down to single segments.

    Args:
        segments: Segments of this chunk.
        state: Memoir state forwarded to ``run_batch_phase1_prep``.
        llm: LLM client forwarded to ``run_batch_phase1_prep``.
        language: Prompt language code.

    Returns:
        Mapping of ``str(segment.id)`` to its row for every segment in the chunk.

    Raises:
        ValueError: a call on a single segment (or an empty chunk) failed and
            is re-raised unchanged, or the merged halves do not cover exactly
            the chunk's ids.
    """
    try:
        return run_batch_phase1_prep(segments, state, llm, language=language)
    except ValueError as exc:
        if len(segments) <= 1:
            # Nothing left to split: surface the error so the caller can fall back.
            raise
        logger.info(
            "event=batch_phase1_bisect segment_count={} error={} "
            "msg=Phase1 batch chunk failed; splitting in half",
            len(segments),
            exc,
        )
    # Recurse outside the ``except`` block so that failures in the halves are
    # not implicitly chained onto this chunk's error via ``__context__``.
    mid = len(segments) // 2  # >= 1, since len(segments) >= 2 here
    left = _run_batch_phase1_prep_chunk_with_bisect(
        segments[:mid], state, llm, language=language
    )
    right = _run_batch_phase1_prep_chunk_with_bisect(
        segments[mid:], state, llm, language=language
    )
    merged = {**left, **right}
    expected = {str(s.id) for s in segments}
    if merged.keys() != expected:
        raise ValueError(
            "batch phase1 chunked bisect merge: segment ids do not match input"
        )
    return merged
def run_batch_phase1_prep_chunked(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    chunk_size: int,
    on_chunk: Callable[[int, int], None] | None = None,
    language: str = "zh",
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run Phase1 batch prep over ``segments`` in slices of ``chunk_size``.

    Each slice goes through the bisecting single-chunk runner, so a failing
    slice is split and retried down to single segments. Only then does the
    error escape to the orchestrator's outer per-segment fallback.

    Args:
        segments: All segments to process, in order.
        state: Memoir state forwarded to every chunk call.
        llm: LLM client forwarded to every chunk call.
        chunk_size: Segments per slice; values below 1 are treated as 1.
        on_chunk: Optional progress callback, invoked after each slice with
            ``(1-based chunk index, total chunk count)``.
        language: Prompt language code.

    Returns:
        Mapping of ``str(segment.id)`` to its row across all slices.
        Empty when ``segments`` is empty.

    Raises:
        ValueError: a slice failed even after bisection, or the merged ids do
            not match the input ids.
    """
    if not segments:
        return {}
    step = max(chunk_size, 1)
    total = len(segments)
    chunk_count = max(1, math.ceil(total / step))
    combined: Dict[str, BatchPhase1SegmentRow] = {}
    for position, start in enumerate(range(0, total, step), start=1):
        window = segments[start : start + step]
        logger.info(
            "event=batch_phase1_chunk chunk_idx={}/{} segment_count={} batch_path=chunked "
            "msg=Phase1 批处理分块调用",
            position,
            chunk_count,
            len(window),
        )
        combined.update(
            _run_batch_phase1_prep_chunk_with_bisect(
                window, state, llm, language=language
            )
        )
        if on_chunk is not None:
            on_chunk(position, chunk_count)
    wanted = {str(seg.id) for seg in segments}
    if combined.keys() == wanted:
        return combined
    missing = wanted - combined.keys()
    extra = combined.keys() - wanted
    logger.warning(
        "batch phase1 chunked id mismatch missing={} extra={}",
        missing,
        extra,
    )
    raise ValueError("batch phase1 chunked: merged segment ids do not match input")