Files
life-echo/api/app/agents/memoir/extraction_agent.py
Kevin 07c6478742 feat(api): 访谈路径轻量门控、Memoir Phase1 批处理与叙事/记忆管线加固
- 新增 utterance_substance:短时/应答/元话语可跳过记忆检索、阶段 LLM 与资料抽取 LLM;可配置
- 输入归一化:LLM 模式默认仅语音/ASR;配置项写入 .env.example
- Memoir Phase1:可选 batch LLM 一次性抽取+分类(失败回退逐段);Extraction 空槽位时阶段与 current_stage 对齐,prompt 约束收紧
- 叙事与忠实度:narrative_safety、证据重叠/场合锚点、标题 slots 与履历短语 grounded;fidelity 解析失败 fail-open 可配置
- 章节管线:锁 TTL 上调、锁竞争 Celery 重试、Phase2 immediate singleflight 等;story_pipeline_sync / chapter_compose / memoir_tasks 联动
- Memory:compaction / repo / summarizer / evidence 小修;事实 FTS 未命中是否回退最近事实可配置
- 新增 memoir_pipeline_trace;补充 memoir_reliability 文档与多项回归/门控测试
2026-04-03 10:12:59 +08:00

86 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ExtractionAgent从用户消息中提取 5-stage 状态与 slots。
对应现有逻辑get_state_extraction_prompt + JSON 解析
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict
from app.agents.memoir.prompts import get_state_extraction_prompt
from app.agents.stage_constants import normalize_chat_stage
from app.core.langchain_llm import invoke_json_object
from app.core.logging import get_logger
from app.core.json_utils import extract_json_payload
logger = get_logger(__name__)
@dataclass
class ExtractionResult:
    """Outcome of one extraction pass: the chosen stage and the extracted slots."""

    # Stage the user message was classified into.
    detected_stage: str
    # Extracted slot values keyed by slot name; values are plain strings.
    slots: Dict[str, str]
class ExtractionAgent:
    """Extract ``detected_stage`` and ``slots`` from a single user message."""

    def extract(
        self,
        user_message: str,
        current_stage: str,
        stage_slots: Dict[str, Any],
        llm: Any,
    ) -> ExtractionResult:
        """
        Extract structured slots from ``user_message`` and decide its stage.

        ``llm`` must support a synchronous ``.invoke(prompt)`` call (this is
        used inside Celery tasks). When ``llm`` is falsy, ``current_stage`` and
        no slots are returned immediately. Any exception raised while building
        the prompt, calling the LLM or parsing its reply is caught and logged as
        a warning. In that case the result keeps ``current_stage``, plus any
        slots that had already been extracted before the failure.

        Args:
            user_message: Raw text from the user.
            current_stage: Stage the conversation is currently in; used as the
                fallback whenever no stage can be inferred.
            stage_slots: Existing slots per stage. Values exposing
                ``model_dump()`` are converted to plain data for the prompt.
            llm: LLM client passed through to ``invoke_json_object``.

        Returns:
            ExtractionResult holding the normalized stage and the non-empty
            string slots returned by the LLM.
        """
        detected_stage = current_stage
        extracted_slots: Dict[str, str] = {}
        if not llm:
            return ExtractionResult(
                detected_stage=detected_stage, slots=extracted_slots
            )
        try:
            prompt = get_state_extraction_prompt(
                user_message=user_message,
                current_stage=current_stage,
                stage_slots=self._dump_stage_slots(stage_slots),
            )
            raw = invoke_json_object(
                llm,
                prompt,
                max_tokens=1024,
                agent="ExtractionAgent.extract",
            )
            parsed = json.loads(extract_json_payload(raw))
            if not isinstance(parsed, dict):
                # A JSON array/scalar cannot carry "slots"/"detected_stage".
                raise ValueError(
                    f"expected a JSON object, got {type(parsed).__name__}"
                )
            extracted_slots = self._coerce_slots(parsed.get("slots"))
            if not extracted_slots:
                # No substantive slot: do not infer a stage, so meta-utterances
                # are not tagged with an arbitrary stage such as childhood
                # (kept consistent with the server-side guardrail).
                detected_stage = normalize_chat_stage(
                    current_stage, fallback=current_stage
                )
            else:
                raw_detected = parsed.get("detected_stage", current_stage)
                detected_stage = normalize_chat_stage(
                    str(raw_detected) if raw_detected is not None else None,
                    fallback=current_stage,
                )
        except Exception as e:
            # Deliberate best-effort: an LLM or parse failure must not break
            # the caller; fall back to current_stage (and any slots already set).
            logger.warning("ExtractionAgent LLM 解析失败: {}", e)
        return ExtractionResult(detected_stage=detected_stage, slots=extracted_slots)

    @staticmethod
    def _dump_stage_slots(stage_slots: Dict[str, Any] | None) -> Dict[str, Any]:
        """Return ``stage_slots`` with ``model_dump()``-capable values dumped to plain data."""
        return {
            key: value.model_dump() if hasattr(value, "model_dump") else value
            for key, value in (stage_slots or {}).items()
        }

    @staticmethod
    def _coerce_slots(raw_slots: Any) -> Dict[str, str]:
        """Convert the LLM's ``slots`` value into a ``{name: text}`` mapping.

        A non-dict input yields ``{}``. The following values are dropped:
        ``None`` (JSON ``null``), empty lists/tuples/dicts, and strings that are
        blank after stripping. Without this, they would be stored as bogus
        text such as ``"None"`` or ``"[]"`` and would count as substantive
        slots for stage inference. Every remaining value is converted with
        ``str()`` and stripped of surrounding whitespace.
        """
        if not isinstance(raw_slots, dict):
            return {}
        slots: Dict[str, str] = {}
        for name, value in raw_slots.items():
            if value is None:
                continue
            if isinstance(value, (list, tuple, dict)) and not value:
                continue
            text = (value if isinstance(value, str) else str(value)).strip()
            if text:
                slots[name] = text
        return slots