Files
life-echo/api/app/agents/memoir/batch_phase1_prep.py
Kevin ccdc4e4277 feat(i18n): persist language preference and thread through chat, memoir, TTS
- Add users.language_preference (Alembic 0018, default zh); capture at signup/SMS
  only; expose on auth and profile APIs
- Lite English prompts for chat and memoir; localized stage labels and agent
  names (Life Echo / 岁月知己)
- Tencent TTS: language-aware synthesis, ModelType=1 for 501004, English chunking
- WebSocket pipeline: emit all AGENT_RESPONSE segments when TTS cancels; INFO logs
  for tts_this_turn and TTS decisions; on-demand TTS logging
- Expo: device language on auth, i18n tiers/agent name, [SPLIT] streaming UX fixes
- Tests for migration, prompts, pipeline, router tts_this_turn, reply segments

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-11 16:16:49 +08:00

188 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Phase1 批处理:一次 LLM 调用完成多段的抽取 + 章节分类(与逐段循环语义对齐)。
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List
from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt
from app.agents.memoir.schemas import BatchPhase1LLMOutput
from app.agents.state_schema import MemoirStateSchema
from app.core.config import settings
from app.core.llm_call import LLMCallError, llm_json_call
from app.core.logging import get_logger
from app.features.conversation.models import Segment
logger = get_logger(__name__)
def _slots_snapshot(state: MemoirStateSchema) -> dict:
snap: dict = {}
for stage, buckets in (state.slots or {}).items():
snap[stage] = {}
for k, v in (buckets or {}).items():
if hasattr(v, "snippet"):
sn = getattr(v, "snippet", None) or ""
elif isinstance(v, dict):
sn = (
(v.get("snippet") or "")
if isinstance(v.get("snippet"), str)
else ""
)
else:
sn = ""
snap[stage][k] = (sn or "")[:120]
return snap
@dataclass(frozen=True)
class BatchPhase1SegmentRow:
    """Immutable per-segment result row of the batched Phase 1 LLM call."""

    # Stage name for the segment; the batch runner stores it stripped and
    # lower-cased, falling back to the current stage when the LLM gave none.
    detected_stage: str

    # Slot name -> extracted text; the batch runner coerces values to str.
    slots: Dict[str, str]

    # Chapter category string as returned by the LLM ("" when absent).
    chapter_category_raw: str
def run_batch_phase1_prep(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    language: str = "zh",
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run one batched Phase 1 LLM call over ``segments`` and index rows by id.

    A single prompt covering every segment is sent to the LLM.  The parsed
    rows are normalized and returned keyed by segment id; the set of returned
    ids must match the input ids exactly.

    Args:
        segments: Segments to process, in order.  An empty list returns ``{}``.
        state: Memoir state supplying ``current_stage`` (default
            ``"childhood"``) and the slots summarized into the prompt.
        llm: LLM handle forwarded to ``llm_json_call``; must be truthy.
        language: Language code forwarded to the prompt builder.

    Returns:
        Mapping of ``str(segment.id)`` to its ``BatchPhase1SegmentRow``.

    Raises:
        ValueError: If ``llm`` is falsy, the LLM call raises ``LLMCallError``,
            the parsed output has no rows, or the returned ids differ from
            the input ids.
    """
    if not llm:
        raise ValueError("batch phase1 requires llm")
    if not segments:
        return {}

    # (id, stripped user text) pairs embedded in the prompt.
    segment_items = [
        (str(segment.id), (segment.user_input_text or "").strip())
        for segment in segments
    ]
    prompt = get_batch_memoir_phase1_prep_prompt(
        system_current_stage=state.current_stage or "childhood",
        slots_snapshot=_slots_snapshot(state),
        segment_items=segment_items,
        language=language,
    )

    try:
        parsed = llm_json_call(
            llm,
            prompt,
            BatchPhase1LLMOutput,
            max_tokens=int(settings.memoir_phase1_batch_llm_max_tokens),
            agent="BatchPhase1Prep.run",
        )
    except LLMCallError as exc:
        logger.warning("batch phase1 LLM 解析失败: {}", exc)
        raise ValueError("batch phase1: llm parse failed") from exc

    llm_rows = parsed.segments
    if not llm_rows:
        raise ValueError("batch phase1: segments must be a non-empty list")

    indexed: Dict[str, BatchPhase1SegmentRow] = {}
    for entry in llm_rows:
        segment_id = str(entry.id).strip()
        if not segment_id:
            # Rows without an id cannot be matched; the id check below
            # reports the resulting gap.
            continue

        stage = str(entry.detected_stage or "").strip().lower()

        # Keep only non-empty string keys; coerce every value to str.
        normalized_slots: Dict[str, str] = {}
        for slot_name, slot_value in (entry.slots or {}).items():
            if slot_name and isinstance(slot_name, str):
                normalized_slots[slot_name] = (
                    slot_value if isinstance(slot_value, str) else str(slot_value)
                )

        category = str(entry.chapter_category or "")

        # A later row with the same id replaces an earlier one.
        indexed[segment_id] = BatchPhase1SegmentRow(
            detected_stage=stage or (state.current_stage or "childhood"),
            slots=normalized_slots,
            chapter_category_raw=category,
        )

    wanted_ids = {str(segment.id) for segment in segments}
    if indexed.keys() != wanted_ids:
        logger.warning(
            "batch phase1 id mismatch missing={} extra={}",
            wanted_ids - indexed.keys(),
            indexed.keys() - wanted_ids,
        )
        raise ValueError("batch phase1 response segment ids do not match input")
    return indexed
def _run_batch_phase1_prep_chunk_with_bisect(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    language: str = "zh",
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run Phase 1 prep on one chunk, halving the chunk and retrying on failure.

    The whole chunk is first sent through ``run_batch_phase1_prep``.  If that
    raises ``ValueError`` (the original note cites truncated LLM output as an
    example), the chunk is split into two halves and each half is retried
    recursively, down to single segments.

    Args:
        segments: The chunk of segments to process.
        state: Memoir state forwarded to ``run_batch_phase1_prep``.
        llm: LLM handle forwarded to ``run_batch_phase1_prep``.
        language: Language code forwarded to ``run_batch_phase1_prep``.

    Returns:
        Mapping of segment id to row covering every segment in the chunk.

    Raises:
        ValueError: If a chunk of at most one segment still fails, or if the
            merged halves do not cover exactly the input ids.
    """
    try:
        return run_batch_phase1_prep(segments, state, llm, language=language)
    except ValueError:
        # A chunk of zero or one segment cannot be split further; surface the
        # original failure to the caller.
        if len(segments) <= 1:
            raise
        # len(segments) >= 2 here, so the midpoint is always >= 1 and both
        # halves are non-empty; no extra guard on ``mid`` is needed.
        mid = len(segments) // 2
        left = _run_batch_phase1_prep_chunk_with_bisect(
            segments[:mid], state, llm, language=language
        )
        right = _run_batch_phase1_prep_chunk_with_bisect(
            segments[mid:], state, llm, language=language
        )
        merged = {**left, **right}
        expected = {str(s.id) for s in segments}
        if merged.keys() != expected:
            # ``from None`` keeps the already-handled batch failure out of
            # this error's traceback.
            raise ValueError(
                "batch phase1 chunked bisect merge: segment ids do not match input"
            ) from None
        return merged
def run_batch_phase1_prep_chunked(
    segments: List[Segment],
    state: MemoirStateSchema,
    llm: Any,
    *,
    chunk_size: int,
    on_chunk: Callable[[int, int], None] | None = None,
    language: str = "zh",
) -> Dict[str, BatchPhase1SegmentRow]:
    """Run Phase 1 prep over ``segments`` in slices of ``chunk_size``.

    Each slice goes through ``_run_batch_phase1_prep_chunk_with_bisect``,
    which halves a failing slice and retries down to single segments.  The
    per-slice results are merged into one mapping.

    Args:
        segments: All segments to process.  An empty list returns ``{}``.
        state: Memoir state forwarded to every slice call.
        llm: LLM handle forwarded to every slice call.
        chunk_size: Maximum segments per slice; values below 1 are treated as 1.
        on_chunk: Optional callback invoked after each slice is merged, with
            ``(chunk_index, total_chunks)`` where the index is 1-based.
        language: Language code forwarded to every slice call.

    Returns:
        Mapping of segment id to row for every input segment.

    Raises:
        ValueError: Propagated from a slice that still fails after bisecting,
            or raised when the merged ids differ from the input ids.
    """
    if not segments:
        return {}

    step = max(chunk_size, 1)
    segment_total = len(segments)
    chunk_total = max(1, math.ceil(segment_total / step))

    combined: Dict[str, BatchPhase1SegmentRow] = {}
    for chunk_no, start in enumerate(range(0, segment_total, step), start=1):
        window = segments[start : start + step]
        logger.info(
            "event=batch_phase1_chunk chunk_idx={}/{} segment_count={} batch_path=chunked "
            "msg=Phase1 批处理分块调用",
            chunk_no,
            chunk_total,
            len(window),
        )
        combined.update(
            _run_batch_phase1_prep_chunk_with_bisect(
                window, state, llm, language=language
            )
        )
        # Progress is reported only after the slice has been merged.
        if on_chunk is not None:
            on_chunk(chunk_no, chunk_total)

    wanted_ids = {str(segment.id) for segment in segments}
    if combined.keys() != wanted_ids:
        logger.warning(
            "batch phase1 chunked id mismatch missing={} extra={}",
            wanted_ids - combined.keys(),
            combined.keys() - wanted_ids,
        )
        raise ValueError("batch phase1 chunked: merged segment ids do not match input")
    return combined