From 43d1689e9c29b655c68b71aad7baafb9c6db254a Mon Sep 17 00:00:00 2001 From: Kevin Date: Fri, 3 Apr 2026 13:34:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E7=BB=9F=E4=B8=80=20LLM=20JSON=20?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E5=B1=82=20llm=5Fjson=5Fcall=EF=BC=8C?= =?UTF-8?q?=E6=8C=89=E5=9F=9F=20Schema=20=E8=BF=81=E7=A7=BB=20chat/memoir?= =?UTF-8?q?=20agents?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/agents/chat/output_rules.py | 13 + api/app/agents/chat/profile_agent.py | 197 ++++----- api/app/agents/chat/prompts.py | 3 + api/app/agents/chat/prompts_conversation.py | 20 +- api/app/agents/chat/prompts_profile.py | 15 +- api/app/agents/chat/schemas.py | 19 + api/app/agents/chat/stage_detection.py | 51 ++- api/app/agents/chat/stage_prompts.py | 6 +- api/app/agents/memoir/batch_phase1_prep.py | 61 ++- api/app/agents/memoir/classification_agent.py | 11 +- api/app/agents/memoir/extraction_agent.py | 18 +- api/app/agents/memoir/fidelity_check_agent.py | 31 +- api/app/agents/memoir/narrative_agent.py | 25 +- api/app/agents/memoir/prompts.py | 58 +-- api/app/agents/memoir/schemas.py | 53 +++ api/app/agents/memoir/story_route_agent.py | 35 +- api/app/agents/stage_constants.py | 26 ++ api/app/core/config.py | 10 + api/app/core/langchain_llm.py | 39 +- api/app/core/llm_call.py | 402 ++++++++++++++++++ api/docs/llm-json-mode.md | 3 +- api/docs/memoir_reliability.md | 7 + api/tests/test_chat_stage_detection_gates.py | 9 +- api/tests/test_fidelity_gate.py | 13 +- api/tests/test_llm_json_call.py | 150 +++++++ api/tests/test_stage_slot_registry.py | 13 + api/tests/test_stage_validation.py | 33 +- .../test_story_route_prompts_and_behavior.py | 37 +- 28 files changed, 1006 insertions(+), 352 deletions(-) create mode 100644 api/app/agents/chat/output_rules.py create mode 100644 api/app/agents/chat/schemas.py create mode 100644 api/app/agents/memoir/schemas.py create mode 100644 api/app/core/llm_call.py create mode 100644 api/tests/test_llm_json_call.py create mode 100644 api/tests/test_stage_slot_registry.py diff --git a/api/app/agents/chat/output_rules.py b/api/app/agents/chat/output_rules.py new file mode 100644 index 0000000..5fd5d4f --- /dev/null +++ b/api/app/agents/chat/output_rules.py @@ -0,0 +1,13 @@ +"""共用用户可见回复禁令(访谈 / 资料收集)。""" + + +def chat_output_rules() -> str: + """用户可见回复共用禁令(括号/元注释/采访腔/编造等)。""" + return ( + "**禁止**输出括号、括号内的策略/舞台说明(例如「(先接住情绪)」「(共情)」)、" + "思考过程或任何元注释——这些只存在于系统指令里,**绝不可**出现在你对用户说的话中;" + "采访腔(「我注意到」「我想了解」);重复确认对方已经说过或能推断出的信息;编造对方没说的细节。" + ) + + +__all__ = ["chat_output_rules"] diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py index 9d83e5a..cc38059 100644 --- a/api/app/agents/chat/profile_agent.py +++ b/api/app/agents/chat/profile_agent.py @@ -3,7 +3,6 @@ ProfileAgent:用户资料收集 Specialist 负责提取资料、资料追问、资料收集开场白,不负责 Redis 持久化(由 Orchestrator 统一处理) """ -import json import time from typing import Any, Dict, List, Optional @@ -15,11 +14,11 @@ from app.agents.chat.prompts_profile import ( get_profile_followup_prompt, get_profile_greeting_prompt, ) +from app.agents.chat.schemas import ProfileExtractionOutput from app.core.agent_logging import agent_span, log_agent_payload, log_agent_summary from app.core.config import settings from app.core.dependencies import get_llm_provider -from app.core.json_utils import extract_json_payload -from app.core.langchain_llm import ainvoke_json_object +from app.core.llm_call import allm_json_call from app.core.logging import get_logger from app.agents.chat.reply_limits import ( nonempty_segments_or_fallback, @@ -53,6 +52,53 @@ class ProfileAgent: def __init__(self): self.llm = _get_langchain_llm() + async def _invoke_chat( + self, + messages: List[Any], + *, + max_tokens: int, + conversation_id: Optional[str], + agent_name: str, + ) -> str: + chat_llm = self.llm.bind(max_tokens=max_tokens) + llm_t0 = time.perf_counter() + with agent_span( + logger, f"{agent_name}.llm", conversation_id=conversation_id or "" + ): + response = await chat_llm.ainvoke(messages) + logger.info( + "event=chat_llm_done agent={} response_latency_ms={:.2f}", + agent_name, + (time.perf_counter() - llm_t0) * 1000, + ) + return ( + response.content if hasattr(response, "content") else str(response) + ) or "" + + async def _segments_from_response( + self, + response_text: str, + *, + max_segments: int, + max_chars_per_segment: int, + fallback: str, + ) -> List[str]: + log_agent_payload( + logger, + "ProfileAgent._segments_from_response.raw_response", + response_text, + ) + raw_list = segments_from_llm_response(response_text, max_segments=max_segments) + if not raw_list: + raw_list = [response_text.strip()] + out = truncate_chat_segments( + raw_list, + max_segments=max_segments, + max_chars_per_segment=max_chars_per_segment, + ) + segments = out if out else [response_text.strip()[:max_chars_per_segment]] + return nonempty_segments_or_fallback(segments, fallback=fallback) + async def extract_profile_from_message( self, user_message: str, @@ -81,16 +127,17 @@ class ProfileAgent: prompt = get_profile_extraction_prompt( user_message, missing_fields, recent_dialogue=recent_dialogue or None ) - content = await ainvoke_json_object( + parsed = await allm_json_call( self.llm, prompt, - max_tokens=512, + ProfileExtractionOutput, + max_tokens=settings.chat_profile_extract_max_tokens, agent="ProfileAgent.extract_profile_from_message", + fallback_factory=lambda: ProfileExtractionOutput(), ) - parsed = json.loads(extract_json_payload(content)) result = {} - if "birth_year" in parsed and parsed["birth_year"] is not None: - raw = parsed["birth_year"] + if parsed.birth_year is not None: + raw = parsed.birth_year if isinstance(raw, int) and 1900 <= raw <= 2100: result["birth_year"] = raw elif isinstance(raw, str) and raw.isdigit(): @@ -99,14 +146,14 @@ class ProfileAgent: y = 1900 + y if y >= 50 else 2000 + y if 1900 <= y <= 2100: result["birth_year"] = y - if "birth_place" in parsed and parsed["birth_place"]: - result["birth_place"] = str(parsed["birth_place"]) - if "grew_up_place" in parsed and parsed["grew_up_place"]: - result["grew_up_place"] = str(parsed["grew_up_place"]) - if "occupation" in parsed and parsed["occupation"]: - result["occupation"] = str(parsed["occupation"]) + if parsed.birth_place: + result["birth_place"] = str(parsed.birth_place) + if parsed.grew_up_place: + result["grew_up_place"] = str(parsed.grew_up_place) + if parsed.occupation: + result["occupation"] = str(parsed.occupation) return result - except (json.JSONDecodeError, Exception) as e: + except Exception as e: logger.error("提取资料信息失败: {}", e) return {} @@ -143,61 +190,33 @@ class ProfileAgent: "ProfileAgent.followup.prompt", format_history_string(messages), ) - chat_llm = self.llm.bind( - max_tokens=settings.chat_profile_followup_max_tokens - ) - llm_t0 = time.perf_counter() - with agent_span( - logger, - "ProfileAgent.followup.llm", - conversation_id=conversation_id, - ): - logger.info( - "event=chat_prompt_built agent=ProfileAgent.generate_profile_followup " - "prompt_chars={} history_pairs_total={} history_pairs_windowed={}", - _message_contents_char_count(messages), - hw.turn_total, - len(hw.window) // 2, - ) - response = await chat_llm.ainvoke(messages) + prompt_chars = _message_contents_char_count(messages) logger.info( - "event=chat_llm_done agent=ProfileAgent.generate_profile_followup " - "response_latency_ms={:.2f}", - (time.perf_counter() - llm_t0) * 1000, + "event=chat_prompt_built agent=ProfileAgent.generate_profile_followup " + "prompt_chars={} history_pairs_total={} history_pairs_windowed={}", + prompt_chars, + hw.turn_total, + len(hw.window) // 2, ) - response_text = ( - response.content if hasattr(response, "content") else str(response) + response_text = await self._invoke_chat( + messages, + max_tokens=settings.chat_profile_followup_max_tokens, + conversation_id=conversation_id, + agent_name="ProfileAgent.generate_profile_followup", ) - log_agent_payload( - logger, "ProfileAgent.followup.raw_response", response_text - ) - raw_list = segments_from_llm_response(response_text, max_segments=3) - if not raw_list: - raw_list = [response_text.strip()] - out = truncate_chat_segments( - raw_list, + segments = await self._segments_from_response( + response_text, max_segments=3, max_chars_per_segment=settings.chat_interview_max_chars_per_segment, + fallback="谢谢分享!能再告诉我一些吗?", ) log_agent_summary( logger, "ProfileAgent.followup segments={} conversation_id={}", - len(out), + len(segments), conversation_id, ) - segments = ( - out - if out - else [ - response_text.strip()[ - : settings.chat_interview_max_chars_per_segment - ] - ] - ) - return nonempty_segments_or_fallback( - segments, - fallback="谢谢分享!能再告诉我一些吗?", - ) + return segments except Exception as e: logger.error("生成资料跟进回复失败: {}", e) return ["谢谢分享!能再告诉我一些吗?"] @@ -229,61 +248,33 @@ class ProfileAgent: log_agent_payload( logger, "ProfileAgent.greeting.prompt", format_history_string(messages) ) - chat_llm = self.llm.bind( - max_tokens=settings.chat_profile_followup_max_tokens - ) - llm_t0 = time.perf_counter() - with agent_span( - logger, - "ProfileAgent.greeting.llm", - conversation_id=conversation_id, - ): - logger.info( - "event=chat_prompt_built agent=ProfileAgent.generate_profile_greeting " - "prompt_chars={} history_pairs_total={} history_pairs_windowed={}", - _message_contents_char_count(messages), - hw.turn_total, - len(hw.window) // 2, - ) - response = await chat_llm.ainvoke(messages) + prompt_chars = _message_contents_char_count(messages) logger.info( - "event=chat_llm_done agent=ProfileAgent.generate_profile_greeting " - "response_latency_ms={:.2f}", - (time.perf_counter() - llm_t0) * 1000, + "event=chat_prompt_built agent=ProfileAgent.generate_profile_greeting " + "prompt_chars={} history_pairs_total={} history_pairs_windowed={}", + prompt_chars, + hw.turn_total, + len(hw.window) // 2, ) - response_text = ( - response.content if hasattr(response, "content") else str(response) + response_text = await self._invoke_chat( + messages, + max_tokens=settings.chat_profile_followup_max_tokens, + conversation_id=conversation_id, + agent_name="ProfileAgent.generate_profile_greeting", ) - log_agent_payload( - logger, "ProfileAgent.greeting.raw_response", response_text - ) - raw_list = segments_from_llm_response(response_text, max_segments=2) - if not raw_list: - raw_list = [response_text.strip()] - out = truncate_chat_segments( - raw_list, + segments = await self._segments_from_response( + response_text, max_segments=2, max_chars_per_segment=settings.chat_interview_max_chars_per_segment, + fallback="你好!在开始之前,能告诉我你是哪一年出生的吗?", ) log_agent_summary( logger, "ProfileAgent.greeting segments={} conversation_id={}", - len(out), + len(segments), conversation_id, ) - segments = ( - out - if out - else [ - response_text.strip()[ - : settings.chat_interview_max_chars_per_segment - ] - ] - ) - return nonempty_segments_or_fallback( - segments, - fallback="你好!在开始之前,能告诉我你是哪一年出生的吗?", - ) + return segments except Exception as e: logger.error("生成资料收集开场白失败: {}", e) return [ diff --git a/api/app/agents/chat/prompts.py b/api/app/agents/chat/prompts.py index c562562..0320e34 100644 --- a/api/app/agents/chat/prompts.py +++ b/api/app/agents/chat/prompts.py @@ -2,6 +2,8 @@ Chat 模块提示词:用户资料收集 + 对话访谈 """ +from app.agents.chat.output_rules import chat_output_rules + # Profile prompts(用户资料收集) from app.agents.chat.prompts_profile import ( PROFILE_FIELD_NAMES, @@ -20,6 +22,7 @@ from app.agents.chat.prompts_conversation import ( ) __all__ = [ + "chat_output_rules", "PROFILE_FIELD_NAMES", "format_user_profile_context", "get_missing_profile_fields", diff --git a/api/app/agents/chat/prompts_conversation.py b/api/app/agents/chat/prompts_conversation.py index 99d6c1b..61fd0fe 100644 --- a/api/app/agents/chat/prompts_conversation.py +++ b/api/app/agents/chat/prompts_conversation.py @@ -19,7 +19,8 @@ from app.agents.chat.personas import ( get_opening_persona_line, normalize_interview_persona, ) -from app.agents.stage_constants import CHAT_STAGES, STAGE_DISPLAY_ZH +from app.agents.chat.output_rules import chat_output_rules +from app.agents.stage_constants import CHAT_STAGES, STAGE_DISPLAY_ZH, STAGE_ERA_HINTS from app.core.config import settings SLOT_NAME_MAP = { @@ -176,7 +177,7 @@ def get_opening_prompt( ## 格式 - 可用 [SPLIT] 分成最多 2 条;或一条里「问候 + 问题」。 -- **禁止**括号、括号内策略/旁白(如「(先接住情绪)」)、思考过程;不要替用户编回答。 +- {chat_output_rules()} 不要替用户编回答。 {style_examples} @@ -202,18 +203,7 @@ def _build_era_context(current_stage: str, user_profile_context: str) -> str: if not birth_year: return "" - stage_era_map = { - "childhood": (0, 12), - "education": (6, 22), - "career": (18, 50), - "family": (20, 50), - "belief": (30, 60), - # chapter / 防御性 key:与 belief 同档年龄参照 - "beliefs": (30, 60), - "summary": (30, 60), - } - - age_range = stage_era_map.get(current_stage, (0, 30)) + age_range = STAGE_ERA_HINTS.get(current_stage, (0, 30)) era_start = birth_year + age_range[0] era_end = birth_year + age_range[1] @@ -463,7 +453,7 @@ def get_guided_conversation_prompt( {dynamic_guidance}{uncovered_hint} ## 不要做的 -**禁止**输出括号、括号内的策略/舞台说明(例如「(先接住情绪)」「(共情)」)、思考过程或任何元注释——这些只存在于系统指令里,**绝不可**出现在你对用户说的话中;采访腔(「我注意到」「我想了解」);重复确认对方已经说过或能推断出的信息;编造对方没说的细节。 +{chat_output_rules()} 直接输出(仅自然口语,无任何括号前缀或旁白):""" diff --git a/api/app/agents/chat/prompts_profile.py b/api/app/agents/chat/prompts_profile.py index a8abb1d..14a1e31 100644 --- a/api/app/agents/chat/prompts_profile.py +++ b/api/app/agents/chat/prompts_profile.py @@ -4,6 +4,8 @@ from typing import Dict, List, Optional +from app.agents.chat.output_rules import chat_output_rules + PROFILE_FIELD_NAMES = { "birth_year": "出生年份", @@ -40,7 +42,7 @@ def get_profile_greeting_prompt(missing_fields: List[str], nickname: str = "") - - "你现在是做什么工作的呀?或者之前主要从事什么职业?" ## 严格禁止 -- 禁止输出括号注释、思考过程 +- {chat_output_rules()} - 禁止说"我需要收集信息"之类的机械话 - 禁止一次列出所有问题 @@ -71,12 +73,10 @@ def get_profile_extraction_prompt( return f"""请从以下内容中提取用户已提到的基础资料信息。{dialogue_section}用户本轮回答: "{user_message}" -**JSON 输出**:接口已启用 `response_format=json_object`(DeepSeek JSON 模式),你必须只输出一个合法 JSON 对象。 - 需要提取的字段(只提取确实在对话中出现过的): {missing_names} -请返回 JSON 格式,只包含确实提到的字段: +输出示例(只含确实提到的字段;无则 {{}}): {{ "birth_year": 1965, "birth_place": "湖南长沙", @@ -88,9 +88,7 @@ def get_profile_extraction_prompt( 1. birth_year 填整数(四位数),如"65年出生"转为 1965 2. 如果用户在任一轮说过出生地/成长地/职业等,都要提取 3. 只提取明确提到的信息,不要猜测 -4. 如果没有提取到任何信息,返回空对象 {{}} - -只返回 JSON,不要其他内容。""" +4. 如果没有提取到任何信息,返回空对象 {{}}""" def get_profile_followup_prompt( @@ -145,8 +143,7 @@ def get_profile_followup_prompt( 严格禁止: - **严禁再次询问「已知信息」中已列出的内容**(例如已知出生年份就绝不要再问哪年出生) -- 禁止输出括号注释、思考过程 -- 禁止说"我注意到""我需要了解" +- {chat_output_rules()} 回复格式:多条消息用 [SPLIT] 分隔。 直接输出你要说的话:""" diff --git a/api/app/agents/chat/schemas.py b/api/app/agents/chat/schemas.py new file mode 100644 index 0000000..fb96d81 --- /dev/null +++ b/api/app/agents/chat/schemas.py @@ -0,0 +1,19 @@ +"""LLM JSON 边界契约(Chat agents)。""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class StageDetectionOutput(BaseModel): + detected_stage: str = Field(default="", description="CHAT_STAGES key") + + +class ProfileExtractionOutput(BaseModel): + birth_year: int | str | None = None + birth_place: str | None = None + grew_up_place: str | None = None + occupation: str | None = None + + +__all__ = ["ProfileExtractionOutput", "StageDetectionOutput"] diff --git a/api/app/agents/chat/stage_detection.py b/api/app/agents/chat/stage_detection.py index 45d559e..47d47e0 100644 --- a/api/app/agents/chat/stage_detection.py +++ b/api/app/agents/chat/stage_detection.py @@ -4,9 +4,9 @@ from __future__ import annotations -import json from typing import Any, Optional +from app.agents.chat.schemas import StageDetectionOutput from app.agents.chat.stage_prompts import ( VALID_CHAT_LIFE_STAGES, get_chat_stage_detection_prompt, @@ -18,9 +18,8 @@ from app.agents.stage_constants import ( normalize_chat_stage, ) from app.core.config import settings -from app.core.langchain_llm import ainvoke_json_object +from app.core.llm_call import allm_json_call from app.core.logging import get_logger -from app.core.json_utils import extract_json_payload logger = get_logger(__name__) @@ -51,6 +50,11 @@ def keyword_fallback_primary_stage(user_message: str) -> str: return candidates[0] +def _keyword_fallback_stage(user_message: str, fb: str) -> str: + k = keyword_fallback_primary_stage(user_message) + return normalize_chat_stage(k, fb) if k else fb + + async def detect_primary_life_stage( user_message: str, current_stage: str, @@ -64,35 +68,30 @@ async def detect_primary_life_stage( """ fb = normalize_chat_stage(current_stage, "childhood") if not settings.chat_stage_detection_enabled: - k = keyword_fallback_primary_stage(user_message) - return normalize_chat_stage(k, fb) if k else fb + return _keyword_fallback_stage(user_message, fb) if skip_llm and settings.chat_stage_detection_skip_llm_on_insufficient_signal: - k = keyword_fallback_primary_stage(user_message) - return normalize_chat_stage(k, fb) if k else fb + return _keyword_fallback_stage(user_message, fb) if not llm: - k = keyword_fallback_primary_stage(user_message) - return normalize_chat_stage(k, fb) if k else fb + return _keyword_fallback_stage(user_message, fb) - try: - prompt = get_chat_stage_detection_prompt(user_message, fb) - raw = await ainvoke_json_object( - llm, - prompt, - max_tokens=settings.chat_stage_detection_max_tokens, - agent="detect_primary_life_stage", + prompt = get_chat_stage_detection_prompt(user_message, fb) + + def fallback_factory() -> StageDetectionOutput: + return StageDetectionOutput( + detected_stage=_keyword_fallback_stage(user_message, fb) ) - if not raw.strip(): - k = keyword_fallback_primary_stage(user_message) - return normalize_chat_stage(k, fb) if k else fb - parsed = json.loads(extract_json_payload(raw)) - detected = parsed.get("detected_stage", fb) - return normalize_chat_stage(str(detected) if detected is not None else "", fb) - except (json.JSONDecodeError, Exception) as e: - logger.warning("detect_primary_life_stage 解析失败,使用关键词回退: {}", e) - k = keyword_fallback_primary_stage(user_message) - return normalize_chat_stage(k, fb) if k else fb + + result = await allm_json_call( + llm, + prompt, + StageDetectionOutput, + max_tokens=settings.chat_stage_detection_max_tokens, + agent="detect_primary_life_stage", + fallback_factory=fallback_factory, + ) + return normalize_chat_stage(result.detected_stage, fb) def life_stage_display_name(stage: str) -> str: diff --git a/api/app/agents/chat/stage_prompts.py b/api/app/agents/chat/stage_prompts.py index 145af12..62716c9 100644 --- a/api/app/agents/chat/stage_prompts.py +++ b/api/app/agents/chat/stage_prompts.py @@ -26,8 +26,7 @@ def get_chat_stage_detection_prompt(user_message: str, current_stage: str) -> st 用户话语: "{user_message}" -**JSON 输出**:只输出一个合法 JSON 对象,不要 markdown 或其它文字,例如: -{{"detected_stage":"education"}} +输出形状示例:{{"detected_stage":"education"}} 规则: 1. 根据**本轮**与人生故事相关的实质内容判断主阶段;不要因系统当前阶段而强行归类。 @@ -37,5 +36,4 @@ def get_chat_stage_detection_prompt(user_message: str, current_stage: str) -> st 5. 若主要是价值观、信念、人生感悟、遗憾与骄傲等 → belief。 6. 若主要是童年、幼年成长环境、小时候 → childhood。 7. 若本轮**没有**任何与人生经历相关的实质内容(纯寒暄、谢谢、指令、语气词),则 detected_stage 取 **{current_stage}**(保持不动)。 - -只返回 JSON。""" +""" diff --git a/api/app/agents/memoir/batch_phase1_prep.py b/api/app/agents/memoir/batch_phase1_prep.py index 47e6f98..5d855bc 100644 --- a/api/app/agents/memoir/batch_phase1_prep.py +++ b/api/app/agents/memoir/batch_phase1_prep.py @@ -4,30 +4,22 @@ Phase1 批处理:一次 LLM 调用完成多段的抽取 + 章节分类(与 from __future__ import annotations -import json from dataclasses import dataclass from typing import Any, Dict, List from app.agents.memoir.prompts import get_batch_memoir_phase1_prep_prompt +from app.agents.memoir.schemas import BatchPhase1LLMOutput from app.agents.state_schema import MemoirStateSchema +from app.agents.stage_constants import STAGE_SLOT_KEYS from app.core.config import settings -from app.core.json_utils import extract_json_payload -from app.core.langchain_llm import invoke_json_object +from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger from app.features.conversation.models import Segment logger = get_logger(__name__) STAGE_ALLOWED_SLOTS: Dict[str, frozenset[str]] = { - "childhood": frozenset( - {"place", "people", "daily_life", "emotion", "turning_event"} - ), - "education": frozenset({"school", "city", "motivation", "challenge", "change"}), - "career": frozenset({"job", "environment", "decision", "pressure", "growth"}), - "family": frozenset( - {"relationship", "conflict", "support", "responsibility", "change"} - ), - "belief": frozenset({"value", "regret", "pride", "lesson"}), + k: frozenset(v) for k, v in STAGE_SLOT_KEYS.items() } @@ -73,32 +65,35 @@ def run_batch_phase1_prep( slots_snapshot=_slots_snapshot(state), segment_items=items, ) - raw = invoke_json_object( - llm, - prompt, - max_tokens=int(settings.memoir_phase1_batch_llm_max_tokens), - agent="BatchPhase1Prep.run", - ) - parsed = json.loads(extract_json_payload(raw)) - rows = parsed.get("segments") or [] - if not isinstance(rows, list): - raise ValueError("batch phase1: segments must be a list") + try: + parsed = llm_json_call( + llm, + prompt, + BatchPhase1LLMOutput, + max_tokens=int(settings.memoir_phase1_batch_llm_max_tokens), + agent="BatchPhase1Prep.run", + ) + except LLMCallError as e: + logger.warning("batch phase1 LLM 解析失败: {}", e) + raise ValueError("batch phase1: llm parse failed") from e + + rows = parsed.segments + if not rows: + raise ValueError("batch phase1: segments must be a non-empty list") by_id: Dict[str, BatchPhase1SegmentRow] = {} for row in rows: - if not isinstance(row, dict): - continue - sid = str(row.get("id", "")).strip() + sid = str(row.id).strip() if not sid: continue - ds = str(row.get("detected_stage", "") or "").strip().lower() - slots_raw = row.get("slots") or {} - slots: Dict[str, str] = {} - if isinstance(slots_raw, dict): - for k, v in slots_raw.items(): - if k and isinstance(k, str): - slots[k] = v if isinstance(v, str) else str(v) - cat_raw = str(row.get("chapter_category", row.get("category", "")) or "") + ds = str(row.detected_stage or "").strip().lower() + slots_raw = row.slots or {} + slots = { + k: v if isinstance(v, str) else str(v) + for k, v in slots_raw.items() + if k and isinstance(k, str) + } + cat_raw = str(row.chapter_category or "") by_id[sid] = BatchPhase1SegmentRow( detected_stage=ds or (state.current_stage or "childhood"), slots=slots, diff --git a/api/app/agents/memoir/classification_agent.py b/api/app/agents/memoir/classification_agent.py index c281cc2..f04a519 100644 --- a/api/app/agents/memoir/classification_agent.py +++ b/api/app/agents/memoir/classification_agent.py @@ -14,13 +14,15 @@ from dataclasses import dataclass from typing import Any from app.agents.memoir.prompts import get_chapter_classification_json_prompt +from app.agents.memoir.schemas import ClassificationOutput from app.agents.stage_constants import ( CHAPTER_CATEGORIES, STAGE_KEYWORD_WEIGHTS, STAGE_TO_DEFAULT_CATEGORY, ) +from app.core.config import settings from app.core.json_utils import extract_json_payload -from app.core.langchain_llm import invoke_json_object +from app.core.llm_call import llm_json_call from app.core.logging import get_logger logger = get_logger(__name__) @@ -135,13 +137,14 @@ class ClassificationAgent: if llm: try: prompt = get_chapter_classification_json_prompt(text) - raw = invoke_json_object( + out = llm_json_call( llm, prompt, - max_tokens=256, + ClassificationOutput, + max_tokens=settings.memoir_classification_max_tokens, agent="ClassificationAgent.classify", ) - category = _parse_category_from_llm_response(raw) + category = _normalize_llm_category(out.category) if category == "none": logger.info( "event=chapter_classification_summary_fallback reason=llm_none " diff --git a/api/app/agents/memoir/extraction_agent.py b/api/app/agents/memoir/extraction_agent.py index f0db69d..956369e 100644 --- a/api/app/agents/memoir/extraction_agent.py +++ b/api/app/agents/memoir/extraction_agent.py @@ -5,15 +5,15 @@ ExtractionAgent:从用户消息中提取 5-stage 状态与 slots。 from __future__ import annotations -import json from dataclasses import dataclass from typing import Any, Dict from app.agents.memoir.prompts import get_state_extraction_prompt +from app.agents.memoir.schemas import StateExtractionOutput from app.agents.stage_constants import normalize_chat_stage -from app.core.langchain_llm import invoke_json_object +from app.core.config import settings +from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger -from app.core.json_utils import extract_json_payload logger = get_logger(__name__) @@ -57,14 +57,14 @@ class ExtractionAgent: for k, v in (stage_slots or {}).items() }, ) - raw = invoke_json_object( + parsed = llm_json_call( llm, prompt, - max_tokens=1024, + StateExtractionOutput, + max_tokens=settings.memoir_extraction_max_tokens, agent="ExtractionAgent.extract", ) - parsed = json.loads(extract_json_payload(raw)) - raw_slots = parsed.get("slots", {}) or {} + raw_slots = parsed.slots or {} extracted_slots = { k: v if isinstance(v, str) else str(v) for k, v in raw_slots.items() } @@ -74,12 +74,12 @@ class ExtractionAgent: current_stage, fallback=current_stage ) else: - raw_detected = parsed.get("detected_stage", current_stage) + raw_detected = parsed.detected_stage or current_stage detected_stage = normalize_chat_stage( str(raw_detected) if raw_detected is not None else None, fallback=current_stage, ) - except (json.JSONDecodeError, Exception) as e: + except LLMCallError as e: logger.warning("ExtractionAgent LLM 解析失败: {}", e) return ExtractionResult(detected_stage=detected_stage, slots=extracted_slots) diff --git a/api/app/agents/memoir/fidelity_check_agent.py b/api/app/agents/memoir/fidelity_check_agent.py index 6c1f7b6..29d9a94 100644 --- a/api/app/agents/memoir/fidelity_check_agent.py +++ b/api/app/agents/memoir/fidelity_check_agent.py @@ -6,14 +6,13 @@ FidelityCheckAgent:比较「用户口述」与叙事 JSON 输出,判定是 from __future__ import annotations -import json import re from typing import Any +from app.agents.memoir.schemas import FidelityOutput from app.core.config import settings -from app.core.langchain_llm import invoke_json_object +from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger -from app.core.json_utils import extract_json_payload logger = get_logger(__name__) @@ -86,12 +85,8 @@ class FidelityCheckAgent: 判断:生成稿是否出现**既不在本轮口述、也不在已有正文**的具体新实体或虚构细节? 若内容可归因于上述两个来源的合理书面化整理,pass=true。 -**JSON 输出**:只输出一个合法 JSON 对象。 -{{"pass": true, "reason": null}} -或 -{{"pass": false, "reason": "一句话说明"}} - -只输出 JSON,不要其它文字。""" +输出形状示例: +{{"pass": true, "reason": null}} 或 {{"pass": false, "reason": "一句话说明"}}""" else: prompt = f"""你是事实核对员。比较用户口述与模型生成的叙事。 @@ -106,28 +101,24 @@ class FidelityCheckAgent: 判断:生成稿是否出现口述中**明显没有**的具体新实体或虚构细节? 若仅为口述的书面化整理(含文学性改写、情感渲染、过渡衔接),pass=true。 -**JSON 输出**:只输出一个合法 JSON 对象。 -{{"pass": true, "reason": null}} -或 -{{"pass": false, "reason": "一句话说明"}} - -只输出 JSON,不要其它文字。""" +输出形状示例: +{{"pass": true, "reason": null}} 或 {{"pass": false, "reason": "一句话说明"}}""" try: - raw = invoke_json_object( + out = llm_json_call( llm, prompt, + FidelityOutput, max_tokens=settings.memoir_fidelity_check_max_tokens, agent="FidelityCheckAgent.passes", ) - data = json.loads(extract_json_payload(raw)) - ok = bool(data.get("pass", True)) + ok = bool(out.pass_) if not ok: logger.warning( "event=fidelity_check_fail reason={}", - (data.get("reason") or "")[:200], + (out.reason or "")[:200], ) return ok - except Exception as e: + except LLMCallError as e: logger.warning("FidelityCheckAgent 解析失败: {}", e) if is_append or settings.memoir_fidelity_fail_open_on_parse_error: logger.info("event=fidelity_parse_fail_open is_append={}", is_append) diff --git a/api/app/agents/memoir/narrative_agent.py b/api/app/agents/memoir/narrative_agent.py index 32a6e79..f34131c 100644 --- a/api/app/agents/memoir/narrative_agent.py +++ b/api/app/agents/memoir/narrative_agent.py @@ -5,7 +5,6 @@ NarrativeAgent:生成创意标题和叙事改写。 from __future__ import annotations -import json from typing import Any, Dict, Optional from app.agents.stage_constants import CHAPTER_CATEGORIES @@ -14,9 +13,11 @@ from app.agents.memoir.prompts import ( get_narrative_json_prompt, get_narrative_merge_json_prompt, ) +from app.agents.memoir.schemas import MemoirTitleOutput +from app.core.config import settings from app.core.langchain_llm import invoke_json_object +from app.core.llm_call import llm_json_call from app.core.logging import get_logger -from app.core.json_utils import extract_json_payload logger = get_logger(__name__) @@ -44,17 +45,23 @@ class NarrativeAgent: user_profile=user_profile, birth_year=birth_year, ) - raw = invoke_json_object( + default_title = f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆" + + def _title_fallback() -> MemoirTitleOutput: + return MemoirTitleOutput(title=default_title) + + out = llm_json_call( llm, prompt, - max_tokens=256, + MemoirTitleOutput, + max_tokens=settings.memoir_title_max_tokens, agent="NarrativeAgent.generate_title", + fallback_factory=_title_fallback, ) - data = json.loads(extract_json_payload(raw)) - title = (data.get("title") or "").strip() if isinstance(data, dict) else "" + title = (out.title or "").strip() if title: return title.strip('"') - return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆" + return default_title except Exception as e: logger.warning("NarrativeAgent 生成标题失败: {}", e) return f"{CHAPTER_CATEGORIES.get(stage, stage)} 回忆" @@ -100,7 +107,7 @@ class NarrativeAgent: background_voice=background_voice, occupation=occupation, ) - max_tokens = 8192 + max_tokens = int(settings.memoir_narrative_merge_max_tokens) agent_name = "NarrativeAgent.generate_narrative_merge" else: prompt = get_narrative_json_prompt( @@ -113,7 +120,7 @@ class NarrativeAgent: background_voice=background_voice, occupation=occupation, ) - max_tokens = 4096 + max_tokens = int(settings.memoir_narrative_max_tokens) agent_name = "NarrativeAgent.generate_narrative" return invoke_json_object( llm, diff --git a/api/app/agents/memoir/prompts.py b/api/app/agents/memoir/prompts.py index 474d26a..d652e53 100644 --- a/api/app/agents/memoir/prompts.py +++ b/api/app/agents/memoir/prompts.py @@ -9,6 +9,7 @@ from typing import Optional from app.agents.chat.background_voice import get_background_voice_narrative_block from app.agents.chat.occupation_context import get_occupation_narrative_hint +from app.agents.stage_constants import STAGE_ERA_HINTS, STAGE_SLOT_KEYS from app.features.memory.evidence_format import ( dedupe_evidence_chunk_rows, format_evidence_chunks_for_prompt, @@ -119,9 +120,8 @@ childhood, education, career_early, career_achievement, career_challenge, family 对话内容: {segments_text} -**JSON 输出**:`response_format=json_object`,只输出: +输出形状(仅此对象): {{"category": "childhood|education|career_early|career_achievement|career_challenge|family|beliefs|summary|none"}} -不要其它文字。 若你返回 **none**,服务端会将本段映射到 **summary** 章节并仍写入回忆录正文(不落库丢弃)。""" @@ -131,21 +131,13 @@ def get_state_extraction_prompt( ) -> str: """抽取结构化信息并判断阶段""" slot_keys = list(stage_slots.keys()) - all_stage_slots = { - "childhood": ["place", "people", "daily_life", "emotion", "turning_event"], - "education": ["school", "city", "motivation", "challenge", "change"], - "career": ["job", "environment", "decision", "pressure", "growth"], - "family": ["relationship", "conflict", "support", "responsibility", "change"], - "belief": ["value", "regret", "pride", "lesson"], - } + all_stage_slots = {k: list(v) for k, v in STAGE_SLOT_KEYS.items()} return f"""你是回忆录访谈信息抽取助手。从用户话语中提取结构化信息,判断用户实际在谈论哪个人生阶段。 只提取口述中确有依据的片段,不得编造或推测。 你需要从用户话语中**先提炼与人生经历相关的核心内容**,然后抽取结构化信息(slots 仅填口述中确有依据的片段)。 -**JSON 输出**:接口已启用 `response_format=json_object`,你必须只输出一个合法 JSON 对象,不要 markdown 代码块或其它文字。 - 系统当前跟踪的阶段:{current_stage} 该阶段可填 slots:{slot_keys} @@ -189,6 +181,10 @@ def get_batch_memoir_phase1_prep_prompt( for sid, text in segment_items: lines.append(f"- id={sid}\n 文本:{text}") + slot_lines = "\n".join( + f"- {st}: {', '.join(keys)}" for st, keys in STAGE_SLOT_KEYS.items() + ) + return f"""你是回忆录访谈助手。下面有多段用户口述(按时间顺序),请**逐段**完成: 1)信息抽取(slots、detected_stage)——规则与单段抽取相同; 2)章节分类(chapter_category)——规则与单段分类相同。 @@ -199,11 +195,7 @@ def get_batch_memoir_phase1_prep_prompt( detected_stage 仅允许:childhood | education | career | family | belief slots 的 key 必须属于该 detected_stage 对应集合: -- childhood: place, people, daily_life, emotion, turning_event -- education: school, city, motivation, challenge, change -- career: job, environment, decision, pressure, growth -- family: relationship, conflict, support, responsibility, change -- belief: value, regret, pride, lesson +{slot_lines} chapter_category 仅允许:childhood | education | career_early | career_achievement | career_challenge | family | beliefs | summary | **none** (不足以成篇的档案点/纯寒暄 → **none**;与单段分类一致。) @@ -211,7 +203,7 @@ chapter_category 仅允许:childhood | education | career_early | career_achie 逐段任务(按下列列表顺序,**segments 数组须覆盖每一行 id,且顺序一致**): {chr(10).join(lines)} -**JSON 输出**:只输出一个合法 JSON 对象,不要 markdown。格式: +输出 JSON 对象(无 markdown),格式: {{ "segments": [ {{ @@ -228,22 +220,10 @@ chapter_category 仅允许:childhood | education | career_early | career_achie def _build_age_hint(stage: str, birth_year: Optional[int] = None) -> str: - """根据人生阶段和出生年份推算大致年龄区间""" + """根据人生阶段和出生年份推算大致年龄区间(`STAGE_ERA_HINTS`,仅作提示)。""" if not birth_year: return "" - stage_age_ranges = { - "childhood": (0, 12), - "education": (6, 22), - "career": (18, 60), - "career_early": (18, 30), - "career_achievement": (25, 55), - "career_challenge": (20, 55), - "family": (20, 60), - "belief": (30, 70), - "beliefs": (30, 70), - "summary": (50, 80), - } - age_range = stage_age_ranges.get(stage) + age_range = STAGE_ERA_HINTS.get(stage) if not age_range: return "" year_start = birth_year + age_range[0] @@ -298,9 +278,8 @@ def get_creative_title_json_prompt( ) return ( base.rstrip() - + "\n\n**JSON 输出**:`response_format=json_object`,只输出:" + + "\n\n输出示例(仅此 JSON 对象):" + '\n{"title":"完整标题一行(含时间标注 · 正文格式)"}\n' - + "不要其它文字。" ) @@ -331,8 +310,7 @@ def get_narrative_json_prompt( return f"""{get_narrative_editor_system_prompt(background_voice=background_voice, occupation=occupation)} -请将「本段用户口述」改写为第一人称书面叙述,并输出 **纯 JSON**,不要包含任何其他文字或 markdown 代码块。 -**JSON 输出**:接口已启用 `response_format=json_object`(与 DeepSeek JSON 模式一致),只输出一个合法 JSON 对象。 +请将「本段用户口述」改写为第一人称书面叙述,并输出 **纯 JSON**(无 markdown 围栏)。 阶段:{stage} 可用信息(slots):{slots}{profile_section}{time_section} @@ -343,7 +321,7 @@ def get_narrative_json_prompt( ## 要求 1. **格式与输出**:只输出 JSON;第一人称;不使用 `#`、`##`、表格;`content` 仅含正文。 -2. **事实与取材**:(须遵守系统说明中的事实边界规则 1–4)。只展开「本段用户口述」;若有参考摘录区,不得把摘录中的具体事实写成本轮亲历;过滤语气词与寒暄;不重复已有故事全文;本批同一主题/事件链;段落数量与长度随材料,禁止为凑字数编造。 +2. **事实与取材**:遵守事实边界,不补写未给出的细节。只展开「本段用户口述」;若有参考摘录区,不得把摘录中的具体事实写成本轮亲历;过滤语气词与寒暄;不重复已有故事全文;本批同一主题/事件链;段落数量与长度随材料,禁止为凑字数编造。 3. **不推断结局**:用户未明确说结果(是否录取、是否被选中等)时,不要凭常识补全为确定结论。 ## 输出格式(严格 JSON) @@ -409,8 +387,6 @@ def get_narrative_merge_json_prompt( 你正在**扩写并重组**一则已有回忆录故事:必须把「已有故事」中的事实全部保留在输出中(可合并重复表述、调整语序),并融入「本段用户口述」中的新事实;按**事件发生的时间顺序**排列段落(早→晚);禁止丢弃未矛盾的旧内容。 -**JSON 输出**:接口已启用 `response_format=json_object`,只输出一个合法 JSON 对象,不要 markdown 代码块。 - 阶段:{stage} 可用信息(slots):{slots}{profile_section}{time_section} @@ -420,7 +396,7 @@ def get_narrative_merge_json_prompt( ## 要求 1. **全文输出**:`paragraphs` 须为重组后的**完整故事正文**(非仅本段)。 -2. **事实边界**:(须遵守系统说明中的事实边界规则 1–4)。不得新增「已有」或「本段」未出现的人名、地点、时间、对话、数字;第一人称、优雅书面语须符合上文传记作家文体说明;不用 `#`、`##`、表格。 +2. **事实边界**:遵守事实边界,不补写未给出的细节。不得新增「已有」或「本段」未出现的人名、地点、时间、对话、数字;第一人称、优雅书面语须符合上文传记作家文体说明;不用 `#`、`##`、表格。 3. 若本段与旧文完全重复或无新信息,可输出与旧文等价重组的正文(不得无故缩短到明显少于旧文)。 4. **不推断结局**:本段未明确结果时,不要补全落选/未通过等确定说法,除非旧文中已有同一事实。 @@ -485,8 +461,6 @@ def get_story_route_prompt( merge_hint = story_route_merge_hint_for_category(chapter_category) return f"""你是回忆录编辑助手。根据本批用户口述与【候选故事】决定 append_story 或 new_story。 -**JSON 输出**:接口已启用 `response_format=json_object`,只输出下面 schema 的一个合法 JSON 对象,不要 markdown。 - ## 两层决策标准(必须先在心里过一遍) 1. **主题连续性信号**:价值观、关系模式、长期总结、同一反思维度;口述是否像在**同一主题容器**里加厚? 2. **事件切换信号**:是否出现**新人物组合、新地点、新时间段、新事件因果链**,与候选正文明显是**另一段经历**? @@ -535,8 +509,6 @@ def get_story_batch_plan_prompt( merge_hint = story_route_merge_hint_for_category(chapter_category) return f"""你是回忆录编辑助手。下面同一章节类别下有一批**按时间顺序**的用户口述片段(每段有 id 与文本)。 -**JSON 输出**:接口已启用 `response_format=json_object`,只输出下面 schema 的一个合法 JSON 对象,不要 markdown。 - ## 两层决策标准(每一块都要应用) 1. **主题连续性信号**:价值观、关系模式、长期总结、同一反思维度。 2. **事件切换信号**:新人物组合、新地点、新时间段、新事件因果链。 diff --git a/api/app/agents/memoir/schemas.py b/api/app/agents/memoir/schemas.py new file mode 100644 index 0000000..5adcf4e --- /dev/null +++ b/api/app/agents/memoir/schemas.py @@ -0,0 +1,53 @@ +"""LLM JSON 边界契约(Memoir agents)。""" + +from __future__ import annotations + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field + + +class ClassificationOutput(BaseModel): + category: str = "" + + +class MemoirTitleOutput(BaseModel): + title: str = "" + + +class FidelityOutput(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + pass_: bool = Field(default=True, alias="pass") + reason: str | None = None + + +class StateExtractionOutput(BaseModel): + detected_stage: str = "" + slots: dict[str, str] = Field(default_factory=dict) + emotion: str | None = None + is_new_chapter: bool | None = None + + +class BatchPhase1SegmentRowOut(BaseModel): + id: str + detected_stage: str = "" + slots: dict[str, str] = Field(default_factory=dict) + chapter_category: str = Field( + default="", + validation_alias=AliasChoices("chapter_category", "category"), + ) + + model_config = ConfigDict(extra="ignore", populate_by_name=True) + + +class BatchPhase1LLMOutput(BaseModel): + segments: list[BatchPhase1SegmentRowOut] + + +__all__ = [ + "BatchPhase1LLMOutput", + "BatchPhase1SegmentRowOut", + "ClassificationOutput", + "FidelityOutput", + "MemoirTitleOutput", + "StateExtractionOutput", +] diff --git a/api/app/agents/memoir/story_route_agent.py b/api/app/agents/memoir/story_route_agent.py index 7bfc013..632e292 100644 --- a/api/app/agents/memoir/story_route_agent.py +++ b/api/app/agents/memoir/story_route_agent.py @@ -15,7 +15,7 @@ from app.agents.memoir.prompts import ( ) from app.agents.memoir.story_route_payload import build_route_candidate_json from app.core.config import settings -from app.core.langchain_llm import invoke_json_object +from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger from app.features.story.models import Story @@ -132,23 +132,23 @@ class StoryRouteAgent: batch_transcript=batch_transcript, candidate_stories_json=payload, ) - try: - raw = invoke_json_object( - llm, - prompt, - max_tokens=1024, - agent="StoryRouteAgent.decide", - ).strip() - data = json.loads(raw) - decision = StoryRouteDecision.model_validate(data) - except Exception as e: - logger.warning("StoryRouteAgent 解析失败: {}", e) + + def _decide_fallback() -> StoryRouteDecision: return StoryRouteDecision( decision="new_story", new_story_title=None, reason="parse_error", ) + decision = llm_json_call( + llm, + prompt, + StoryRouteDecision, + max_tokens=settings.memoir_story_route_max_tokens, + agent="StoryRouteAgent.decide", + fallback_factory=_decide_fallback, + ) + if decision.decision == "append_story": tid = decision.target_story_id if not tid or tid not in valid_story_ids: @@ -188,15 +188,14 @@ class StoryRouteAgent: candidate_stories_json=payload, ) try: - raw = invoke_json_object( + plan = llm_json_call( llm, prompt, - max_tokens=4096, + StoryBatchPlan, + max_tokens=settings.memoir_story_batch_plan_max_tokens, agent="StoryRouteAgent.plan_batch", - ).strip() - data = json.loads(raw) - plan = StoryBatchPlan.model_validate(data) - except Exception as e: + ) + except LLMCallError as e: logger.warning("StoryRouteAgent.plan_batch 解析失败: {}", e) return None diff --git a/api/app/agents/stage_constants.py b/api/app/agents/stage_constants.py index 4e86cba..0ff3001 100644 --- a/api/app/agents/stage_constants.py +++ b/api/app/agents/stage_constants.py @@ -59,6 +59,32 @@ CHAPTER_ORDER = [ "summary", ] +# 访谈阶段 slot 名称(与 `state_schema.default_slots` 必须一致;提示词由此生成) +STAGE_SLOT_KEYS: dict[str, tuple[str, ...]] = { + "childhood": ("place", "people", "daily_life", "emotion", "turning_event"), + "education": ("school", "city", "motivation", "challenge", "change"), + "career": ("job", "environment", "decision", "pressure", "growth"), + "family": ("relationship", "conflict", "support", "responsibility", "change"), + "belief": ("value", "regret", "pride", "lesson"), +} + +# 人生阶段 / 章节类目的年龄参照(仅用于 prompt 时间提示;非业务校验) +STAGE_ERA_HINTS: dict[str, tuple[int, int]] = { + "childhood": (0, 12), + "education": (6, 22), + "career": (18, 50), + "career_early": (18, 30), + "career_achievement": (25, 55), + "career_challenge": (20, 55), + "family": (20, 55), + "belief": (30, 70), + "beliefs": (30, 70), + "summary": (50, 80), +} + +# Naming: `belief` = canonical CHAT_STAGES key; `beliefs` = chapter category (CHAPTER_*). +# 两者之间仅在 normalize_chat_stage / _CHAT_STAGE_SYNONYMS 映射,勿在其它处随意互转。 + # career / career_early 与 belief / beliefs 共用序号:chat-stage 别名与 chapter category 兼容映射 STAGE_TO_ORDER = { "childhood": 0, diff --git a/api/app/core/config.py b/api/app/core/config.py index 8ee8417..3f4b6ba 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -124,6 +124,16 @@ class Settings(BaseSettings): # Memoir Phase1:多 segment 一批一次 LLM 完成抽取+章节分类(失败回退逐段);单段且关时仍逐段 memoir_phase1_batch_llm_enabled: bool = False memoir_phase1_batch_llm_max_tokens: int = Field(default=4096, ge=512, le=32_768) + # Memoir agents:`invoke_json_object` / `llm_json_call` 的 max_tokens(原硬编码迁至配置) + memoir_extraction_max_tokens: int = Field(default=1024, ge=64, le=8192) + memoir_classification_max_tokens: int = Field(default=256, ge=32, le=4096) + memoir_narrative_max_tokens: int = Field(default=4096, ge=256, le=32_768) + memoir_narrative_merge_max_tokens: int = Field(default=8192, ge=256, le=64_000) + memoir_title_max_tokens: int = Field(default=256, ge=32, le=4096) + memoir_story_route_max_tokens: int = Field(default=1024, ge=64, le=8192) + memoir_story_batch_plan_max_tokens: int = Field(default=4096, ge=256, le=32_768) + # 资料抽取(ProfileAgent JSON 模式) + chat_profile_extract_max_tokens: int = Field(default=512, ge=64, le=4096) # ── ASR ─────────────────────────────────────────────────── asr_provider: str = "whisper" diff --git a/api/app/core/langchain_llm.py b/api/app/core/langchain_llm.py index d1ecbc9..6c09ba7 100644 --- a/api/app/core/langchain_llm.py +++ b/api/app/core/langchain_llm.py @@ -20,6 +20,23 @@ from app.core.logging import get_logger logger = get_logger(__name__) +# OpenAI / DeepSeek:使用 response_format=json_object 时,prompt 须含子串「json」 +# (见 DeepSeek JSON Output 指南)。 +_JSON_OBJECT_PROMPT_SUFFIX = ( + "\n\n【JSON】只输出一个合法 JSON 对象,不要其它说明文字或 markdown。" +) + + +def ensure_json_object_prompt_has_json_keyword(prompt: str) -> str: + """ + 若整段 prompt 中未出现 ``json``(大小写不敏感),追加一行合规提示。 + 供所有 ``response_format: json_object`` 调用在发请求前统一处理。 + """ + p = prompt or "" + if "json" in p.casefold(): + return p + return f"{p.rstrip()}{_JSON_OBJECT_PROMPT_SUFFIX}" + def bind_json_object_mode(llm: Any, *, max_tokens: int) -> Any: """返回绑定 `response_format=json_object` 与 `max_tokens` 的 Runnable(通常为 ChatOpenAI)。""" @@ -45,14 +62,15 @@ def invoke_json_object( 同步调用 JSON object 模式;空 content 时可选重试一次(缓解 DeepSeek 偶发空输出)。 仅依赖 bind_json_object_mode,不引用 features。 """ + prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt) bound = bind_json_object_mode(llm, max_tokens=max_tokens) tag = agent or "json_object" - sha = _prompt_sha12(prompt) + sha = _prompt_sha12(prompt_for_api) attempts = 2 if retry_empty else 1 t0 = time.perf_counter() last_content = "" for attempt in range(attempts): - response = bound.invoke(prompt) + response = bound.invoke(prompt_for_api) content = (getattr(response, "content", None) or "").strip() last_content = content if content: @@ -63,7 +81,7 @@ def invoke_json_object( sha, ) _log_json_object_done( - tag, sha, prompt, content, attempt + 1, t0, success=True + tag, sha, prompt_for_api, content, attempt + 1, t0, success=True ) return content if attempt == 0 and retry_empty: @@ -74,7 +92,9 @@ def invoke_json_object( sha, ) logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha) - _log_json_object_done(tag, sha, prompt, last_content, attempts, t0, success=False) + _log_json_object_done( + tag, sha, prompt_for_api, last_content, attempts, t0, success=False + ) return "" @@ -87,14 +107,15 @@ async def ainvoke_json_object( retry_empty: bool = True, ) -> str: """异步版 `invoke_json_object`。""" + prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt) bound = bind_json_object_mode(llm, max_tokens=max_tokens) tag = agent or "json_object" - sha = _prompt_sha12(prompt) + sha = _prompt_sha12(prompt_for_api) attempts = 2 if retry_empty else 1 t0 = time.perf_counter() last_content = "" for attempt in range(attempts): - response = await bound.ainvoke(prompt) + response = await bound.ainvoke(prompt_for_api) content = (getattr(response, "content", None) or "").strip() last_content = content if content: @@ -105,7 +126,7 @@ async def ainvoke_json_object( sha, ) _log_json_object_done( - tag, sha, prompt, content, attempt + 1, t0, success=True + tag, sha, prompt_for_api, content, attempt + 1, t0, success=True ) return content if attempt == 0 and retry_empty: @@ -116,7 +137,9 @@ async def ainvoke_json_object( sha, ) logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha) - _log_json_object_done(tag, sha, prompt, last_content, attempts, t0, success=False) + _log_json_object_done( + tag, sha, prompt_for_api, last_content, attempts, t0, success=False + ) return "" diff --git a/api/app/core/llm_call.py b/api/app/core/llm_call.py new file mode 100644 index 0000000..b14a0ca --- /dev/null +++ b/api/app/core/llm_call.py @@ -0,0 +1,402 @@ +""" +Schema-driven LLM JSON 调用:统一 bind `json_object`、空输出重试、解析校验、结构化日志。 + +`extract_json_payload` 仅在 **decode 失败时** 作为一次兼容性重试;命中时打 +`event=llm_json_compat_strip_hit` 便于后续下线该路径(见计划 Step 13;生产观测零命中后再删 compat)。 +""" + +from __future__ import annotations + +import hashlib +import json +import time +from dataclasses import dataclass +from typing import Any, Callable, Literal, TypeVar + +from pydantic import BaseModel, ValidationError + +from app.core.agent_logging import agent_verbose_enabled, log_agent_payload +from app.core.json_utils import extract_json_payload +from app.core.langchain_llm import ( + bind_json_object_mode, + ensure_json_object_prompt_has_json_keyword, +) +from app.core.logging import get_logger + +logger = get_logger(__name__) + +T = TypeVar("T", bound=BaseModel) + +ErrorKind = Literal["invoke", "decode", "validation"] + + +class LLMCallError(Exception): + """未提供 fallback_factory 且调用链失败时抛出。""" + + def __init__( + self, + kind: ErrorKind, + message: str, + *, + raw_content: str | None = None, + ) -> None: + super().__init__(message) + self.kind: ErrorKind = kind + self.raw_content: str | None = raw_content + + +@dataclass(frozen=True) +class LLMCallMeta: + agent: str + schema_name: str + max_tokens: int + duration_ms: float + attempts: int + parse_ok: bool + used_fallback: bool + error_kind: str | None + + +def _prompt_sha12(prompt: str) -> str: + return hashlib.sha256((prompt or "").encode("utf-8")).hexdigest()[:12] + + +def _invoke_raw_sync( + llm: Any, + prompt: str, + *, + max_tokens: int, + agent: str, + retry_empty: bool, +) -> tuple[str, int]: + prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt) + bound = bind_json_object_mode(llm, max_tokens=max_tokens) + tag = agent or "json_object" + sha = _prompt_sha12(prompt_for_api) + attempts = 2 if retry_empty else 1 + for attempt in range(attempts): + response = bound.invoke(prompt_for_api) + content = (getattr(response, "content", None) or "").strip() + if content: + if attempt > 0: + logger.info( + "json_object 空内容重试成功 agent={} prompt_sha12={}", + tag, + sha, + ) + return content, attempt + 1 + if attempt == 0 and retry_empty: + logger.warning( + "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}", + tag, + attempt, + sha, + ) + logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha) + return "", attempts + + +async def _invoke_raw_async( + llm: Any, + prompt: str, + *, + max_tokens: int, + agent: str, + retry_empty: bool, +) -> tuple[str, int]: + prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt) + bound = bind_json_object_mode(llm, max_tokens=max_tokens) + tag = agent or "json_object" + sha = _prompt_sha12(prompt_for_api) + attempts = 2 if retry_empty else 1 + for attempt in range(attempts): + response = await bound.ainvoke(prompt_for_api) + content = (getattr(response, "content", None) or "").strip() + if content: + if attempt > 0: + logger.info( + "json_object 空内容重试成功 agent={} prompt_sha12={}", + tag, + sha, + ) + return content, attempt + 1 + if attempt == 0 and retry_empty: + logger.warning( + "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}", + tag, + attempt, + sha, + ) + logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha) + return "", attempts + + +def _parse_and_validate( + raw: str, + schema: type[T], + *, + agent: str, +) -> T: + s = (raw or "").strip() + if not s: + raise LLMCallError( + "decode", "empty llm content for json parse", raw_content=raw + ) + + data: Any + try: + data = json.loads(s) + except json.JSONDecodeError: + stripped = extract_json_payload(s) + if stripped != s: + logger.warning( + "event=llm_json_compat_strip_hit agent={} prompt_kind=decode_retry", + agent, + ) + try: + data = json.loads(stripped) + except json.JSONDecodeError as e: + raise LLMCallError( + "decode", + f"json decode failed: {e}", + raw_content=s[:4096], + ) from e + + try: + return schema.model_validate(data) + except ValidationError as e: + raise LLMCallError( + "validation", + f"pydantic validation failed: {e}", + raw_content=s[:4096], + ) from e + + +def _emit_meta( + *, + agent: str, + schema_name: str, + max_tokens: int, + t0: float, + attempts: int, + parse_ok: bool, + used_fallback: bool, + error_kind: str | None, +) -> None: + meta = LLMCallMeta( + agent=agent, + schema_name=schema_name, + max_tokens=max_tokens, + duration_ms=(time.perf_counter() - t0) * 1000, + attempts=attempts, + parse_ok=parse_ok, + used_fallback=used_fallback, + error_kind=error_kind, + ) + logger.bind( + event="llm_json_call", + agent=meta.agent, + schema=meta.schema_name, + max_tokens=meta.max_tokens, + duration_ms=round(meta.duration_ms, 2), + attempts=meta.attempts, + parse_ok=meta.parse_ok, + used_fallback=meta.used_fallback, + error_kind=meta.error_kind, + ).info("llm_json_call_done") + + +def llm_json_call( + llm: Any, + prompt: str, + schema: type[T], + *, + max_tokens: int, + agent: str, + fallback_factory: Callable[[], T] | None = None, + retry_empty: bool = True, +) -> T: + """同步:invoke → 解析 JSON → `schema.model_validate`;失败时 `fallback_factory` 或 `LLMCallError`。""" + t0 = time.perf_counter() + schema_name = getattr(schema, "__name__", str(schema)) + attempts_used = 0 + raw = "" + + try: + raw, attempts_used = _invoke_raw_sync( + llm, + prompt, + max_tokens=max_tokens, + agent=agent, + retry_empty=retry_empty, + ) + out = _parse_and_validate(raw, schema, agent=agent) + _emit_meta( + agent=agent, + schema_name=schema_name, + max_tokens=max_tokens, + t0=t0, + attempts=attempts_used, + parse_ok=True, + used_fallback=False, + error_kind=None, + ) + if agent_verbose_enabled(): + log_agent_payload( + logger, + f"{agent}.prompt", + ensure_json_object_prompt_has_json_keyword(prompt), + ) + log_agent_payload(logger, f"{agent}.response", raw) + return out + except LLMCallError as e: + used_fb = fallback_factory is not None + _emit_meta( + agent=agent, + schema_name=schema_name, + max_tokens=max_tokens, + t0=t0, + attempts=attempts_used, + parse_ok=False, + used_fallback=used_fb, + error_kind=e.kind, + ) + if agent_verbose_enabled(): + log_agent_payload( + logger, + f"{agent}.prompt", + ensure_json_object_prompt_has_json_keyword(prompt), + ) + log_agent_payload(logger, f"{agent}.response", raw) + if fallback_factory is not None: + return fallback_factory() + raise + except Exception as e: + logger.bind(agent=agent).exception("llm_json_call invoke error: {}", e) + used_fb = fallback_factory is not None + _emit_meta( + agent=agent, + schema_name=schema_name, + max_tokens=max_tokens, + t0=t0, + attempts=attempts_used, + parse_ok=False, + used_fallback=used_fb, + error_kind="invoke", + ) + if agent_verbose_enabled(): + log_agent_payload( + logger, + f"{agent}.prompt", + ensure_json_object_prompt_has_json_keyword(prompt), + ) + log_agent_payload(logger, f"{agent}.response", raw) + if fallback_factory is not None: + return fallback_factory() + raise LLMCallError( + "invoke", + str(e), + raw_content=raw[:4096] if raw else None, + ) from e + + +async def allm_json_call( + llm: Any, + prompt: str, + schema: type[T], + *, + max_tokens: int, + agent: str, + fallback_factory: Callable[[], T] | None = None, + retry_empty: bool = True, +) -> T: + """异步版,语义与 `llm_json_call` 一致。""" + t0 = time.perf_counter() + schema_name = getattr(schema, "__name__", str(schema)) + attempts_used = 0 + raw = "" + + try: + raw, attempts_used = await _invoke_raw_async( + llm, + prompt, + max_tokens=max_tokens, + agent=agent, + retry_empty=retry_empty, + ) + out = _parse_and_validate(raw, schema, agent=agent) + _emit_meta( + agent=agent, + schema_name=schema_name, + max_tokens=max_tokens, + t0=t0, + attempts=attempts_used, + parse_ok=True, + used_fallback=False, + error_kind=None, + ) + if agent_verbose_enabled(): + log_agent_payload( + logger, + f"{agent}.prompt", + ensure_json_object_prompt_has_json_keyword(prompt), + ) + log_agent_payload(logger, f"{agent}.response", raw) + return out + except LLMCallError as e: + used_fb = fallback_factory is not None + _emit_meta( + agent=agent, + schema_name=schema_name, + max_tokens=max_tokens, + t0=t0, + attempts=attempts_used, + parse_ok=False, + used_fallback=used_fb, + error_kind=e.kind, + ) + if agent_verbose_enabled(): + log_agent_payload( + logger, + f"{agent}.prompt", + ensure_json_object_prompt_has_json_keyword(prompt), + ) + log_agent_payload(logger, f"{agent}.response", raw) + if fallback_factory is not None: + return fallback_factory() + raise + except Exception as e: + logger.bind(agent=agent).exception("allm_json_call invoke error: {}", e) + used_fb = fallback_factory is not None + _emit_meta( + agent=agent, + schema_name=schema_name, + max_tokens=max_tokens, + t0=t0, + attempts=attempts_used, + parse_ok=False, + used_fallback=used_fb, + error_kind="invoke", + ) + if agent_verbose_enabled(): + log_agent_payload( + logger, + f"{agent}.prompt", + ensure_json_object_prompt_has_json_keyword(prompt), + ) + log_agent_payload(logger, f"{agent}.response", raw) + if fallback_factory is not None: + return fallback_factory() + raise LLMCallError( + "invoke", + str(e), + raw_content=raw[:4096] if raw else None, + ) from e + + +__all__ = [ + "LLMCallError", + "LLMCallMeta", + "allm_json_call", + "llm_json_call", +] diff --git a/api/docs/llm-json-mode.md b/api/docs/llm-json-mode.md index cf7bafd..19be9cc 100644 --- a/api/docs/llm-json-mode.md +++ b/api/docs/llm-json-mode.md @@ -27,7 +27,8 @@ ## DeepSeek 官方建议对齐 1. `response_format`:已由 `bind_json_object_mode` 设置。 -2. Prompt 中应包含 **json** 字样及与解析代码一致的 **字段示例**(各 `prompts_*.py` 中维护)。 +2. Prompt 中应包含 **json** 字样及与解析代码一致的 **字段示例**(各 `prompts_*.py` 中维护)。 + **集中兜底**:[`ensure_json_object_prompt_has_json_keyword`](../app/core/langchain_llm.py) 在 `invoke_json_object` / `ainvoke_json_object` 与 `llm_json_call` / `allm_json_call` 发请求前会检测;若整段 prompt 中仍无子串 `json`(大小写不敏感),会自动追加一行中文 JSON 说明,避免 OpenAI/DeepSeek 返回 400。 3. `max_tokens`:在 `invoke_json_object(..., max_tokens=...)` 与绑定处统一传入;长叙事等场景按需调大。 4. 偶发空内容:由 `invoke_json_object` 记录并重试一次;仍失败则由各 Agent 既有 `try/except` 回退。 diff --git a/api/docs/memoir_reliability.md b/api/docs/memoir_reliability.md index 17bba68..17f2006 100644 --- a/api/docs/memoir_reliability.md +++ b/api/docs/memoir_reliability.md @@ -41,3 +41,10 @@ Targeted regressions live under `api/tests/`: - `test_fidelity_gate.py`, `test_narrative_boundary_regressions.py` - `test_memory_consistency_rules.py`, `test_memoir_idempotency.py` - `test_recompose_retry_policy.py` +- `test_llm_json_call.py`, `test_stage_slot_registry.py` + +## LLM JSON (`llm_json_call`) and compat strip + +- Standard path: `response_format=json_object` → `json.loads` → Pydantic validate. +- On decode failure only, `extract_json_payload` runs once (fence / brace strip). A hit emits **`event=llm_json_compat_strip_hit`** at WARNING. +- **Step 13 (sunset)**: observe this event in production for ~1–2 weeks; if zero hits, remove the compat branch from `app.core.llm_call` and migrate remaining callers off `extract_json_payload` for JSON-mode paths. diff --git a/api/tests/test_chat_stage_detection_gates.py b/api/tests/test_chat_stage_detection_gates.py index 6f35beb..ba9c35c 100644 --- a/api/tests/test_chat_stage_detection_gates.py +++ b/api/tests/test_chat_stage_detection_gates.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock import pytest +from app.agents.chat.schemas import StageDetectionOutput from app.agents.chat.stage_detection import detect_primary_life_stage @@ -11,9 +12,9 @@ from app.agents.chat.stage_detection import detect_primary_life_stage async def test_skip_llm_does_not_call_json_llm(monkeypatch: pytest.MonkeyPatch) -> None: called: list[int] = [] - async def _fake_ainvoke(*_a: object, **_k: object) -> str: + async def _fake_allm(*_a: object, **_k: object) -> StageDetectionOutput: called.append(1) - return '{"detected_stage": "career"}' + return StageDetectionOutput(detected_stage="career") monkeypatch.setattr( "app.agents.chat.stage_detection.settings.chat_stage_detection_enabled", @@ -25,8 +26,8 @@ async def test_skip_llm_does_not_call_json_llm(monkeypatch: pytest.MonkeyPatch) True, ) monkeypatch.setattr( - "app.agents.chat.stage_detection.ainvoke_json_object", - _fake_ainvoke, + "app.agents.chat.stage_detection.allm_json_call", + _fake_allm, ) out = await detect_primary_life_stage( "嗯", diff --git a/api/tests/test_fidelity_gate.py b/api/tests/test_fidelity_gate.py index 94dbb26..9c099ad 100644 --- a/api/tests/test_fidelity_gate.py +++ b/api/tests/test_fidelity_gate.py @@ -8,6 +8,7 @@ import pytest from app.agents.memoir.fidelity_check_agent import FidelityCheckAgent from app.core.config import settings +from app.core.llm_call import LLMCallError def test_fidelity_fail_closed_on_parse_when_not_append( @@ -18,8 +19,8 @@ def test_fidelity_fail_closed_on_parse_when_not_append( agent = FidelityCheckAgent() llm = MagicMock() with patch( - "app.agents.memoir.fidelity_check_agent.invoke_json_object", - side_effect=ValueError("simulated_bad_response"), + "app.agents.memoir.fidelity_check_agent.llm_json_call", + side_effect=LLMCallError("invoke", "simulated_bad_response"), ): assert ( agent.passes( @@ -40,8 +41,8 @@ def test_fidelity_fail_open_on_parse_when_append( agent = FidelityCheckAgent() llm = MagicMock() with patch( - "app.agents.memoir.fidelity_check_agent.invoke_json_object", - side_effect=ValueError("simulated_bad_response"), + "app.agents.memoir.fidelity_check_agent.llm_json_call", + side_effect=LLMCallError("invoke", "simulated_bad_response"), ): assert ( agent.passes( @@ -62,8 +63,8 @@ def test_fidelity_fail_open_global_flag_overrides_append( agent = FidelityCheckAgent() llm = MagicMock() with patch( - "app.agents.memoir.fidelity_check_agent.invoke_json_object", - side_effect=ValueError("simulated_bad_response"), + "app.agents.memoir.fidelity_check_agent.llm_json_call", + side_effect=LLMCallError("invoke", "simulated_bad_response"), ): assert ( agent.passes( diff --git a/api/tests/test_llm_json_call.py b/api/tests/test_llm_json_call.py new file mode 100644 index 0000000..ae8c729 --- /dev/null +++ b/api/tests/test_llm_json_call.py @@ -0,0 +1,150 @@ +"""Behavior tests for `llm_json_call` / `allm_json_call` (schema-driven LLM JSON).""" + +from __future__ import annotations + +import pytest +from langchain_core.messages import AIMessage +from pydantic import BaseModel, Field + +from app.core.langchain_llm import ensure_json_object_prompt_has_json_keyword +from app.core.llm_call import LLMCallError, allm_json_call, llm_json_call + + +class _SmallOut(BaseModel): + answer: str = Field(default="") + score: int = 0 + + +class _SyncFakeLlm: + def __init__(self, contents: list[str]) -> None: + self._contents = list(contents) + self._i = 0 + + def bind(self, **_kwargs: object): + return self + + def invoke(self, _prompt: str): + c = self._contents[min(self._i, len(self._contents) - 1)] + self._i += 1 + return AIMessage(content=c) + + async def ainvoke(self, _prompt: str): + return self.invoke(_prompt) + + +class _CapturePromptLlm: + """Records the prompt passed to ``invoke`` (after ``bind`` noop).""" + + def __init__(self, content: str) -> None: + self.content = content + self.last_prompt: str | None = None + + def bind(self, **_kwargs: object): + return self + + def invoke(self, prompt: str): + self.last_prompt = prompt + return AIMessage(content=self.content) + + async def ainvoke(self, prompt: str): + return self.invoke(prompt) + + +def test_ensure_json_object_prompt_appends_when_no_json_substring() -> None: + out = ensure_json_object_prompt_has_json_keyword("仅中文,无英文字样") + assert "仅中文" in out + assert "json" in out.casefold() + + +def test_ensure_json_object_prompt_unchanged_when_json_present() -> None: + s = "只输出 JSON:{}" + assert ensure_json_object_prompt_has_json_keyword(s) is s + + +def test_llm_json_call_injects_json_keyword_for_api_when_missing() -> None: + llm = _CapturePromptLlm('{"answer": "ok", "score": 2}') + llm_json_call( + llm, + '你是助手。输出:{"x":1}', + _SmallOut, + max_tokens=32, + agent="test_agent", + ) + assert llm.last_prompt is not None + # 用户内容里含 JSON 示例但无字母 json,仍应追加合规后缀 + assert "json" in llm.last_prompt.casefold() + + +def test_llm_json_call_success() -> None: + llm = _SyncFakeLlm(['{"answer": "ok", "score": 2}']) + out = llm_json_call( + llm, + "prompt", + _SmallOut, + max_tokens=32, + agent="test_agent", + ) + assert out.answer == "ok" + assert out.score == 2 + + +def test_llm_json_call_empty_retry_then_success() -> None: + llm = _SyncFakeLlm(["", '{"answer": "retry", "score": 0}']) + out = llm_json_call( + llm, + "p", + _SmallOut, + max_tokens=8, + agent="t", + ) + assert out.answer == "retry" + assert llm._i == 2 + + +def test_llm_json_call_compat_strip() -> None: + raw = '```json\n{"answer": "fenced", "score": 1}\n```' + llm = _SyncFakeLlm([raw]) + out = llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t") + assert out.answer == "fenced" + + +def test_llm_json_call_validation_fallback() -> None: + llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}']) + out = llm_json_call( + llm, + "p", + _SmallOut, + max_tokens=8, + agent="t", + fallback_factory=lambda: _SmallOut(answer="fb"), + ) + assert out.answer == "fb" + + +def test_llm_json_call_no_fallback_raises() -> None: + llm = _SyncFakeLlm(['{"answer": "x", "score": "not_an_int"}']) + with pytest.raises(LLMCallError) as ei: + llm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="t") + assert ei.value.kind == "validation" + + +@pytest.mark.asyncio +async def test_allm_json_call_parity_with_sync() -> None: + llm = _SyncFakeLlm(['{"answer": "async", "score": 7}']) + out = await allm_json_call(llm, "p", _SmallOut, max_tokens=8, agent="at") + assert out.answer == "async" + assert out.score == 7 + + +@pytest.mark.asyncio +async def test_allm_json_call_fallback_on_decode() -> None: + llm = _SyncFakeLlm(["not json at all {{{"]) + out = await allm_json_call( + llm, + "p", + _SmallOut, + max_tokens=8, + agent="at", + fallback_factory=lambda: _SmallOut(answer="dec"), + ) + assert out.answer == "dec" diff --git a/api/tests/test_stage_slot_registry.py b/api/tests/test_stage_slot_registry.py new file mode 100644 index 0000000..5719109 --- /dev/null +++ b/api/tests/test_stage_slot_registry.py @@ -0,0 +1,13 @@ +"""`STAGE_SLOT_KEYS` must stay aligned with `MemoirStateSchema.default_slots`.""" + +from __future__ import annotations + +from app.agents.stage_constants import STAGE_SLOT_KEYS +from app.agents.state_schema import default_slots + + +def test_stage_slot_keys_matches_default_slots() -> None: + slots = default_slots() + assert set(STAGE_SLOT_KEYS.keys()) == set(slots.keys()) + for stage, keys in STAGE_SLOT_KEYS.items(): + assert tuple(slots[stage].keys()) == keys diff --git a/api/tests/test_stage_validation.py b/api/tests/test_stage_validation.py index 7d17de8..39c471a 100644 --- a/api/tests/test_stage_validation.py +++ b/api/tests/test_stage_validation.py @@ -1,11 +1,11 @@ """阶段 / 章节归一化纯函数(VALID + normalize + chat_bucket)。""" -import json from unittest.mock import MagicMock import pytest from app.agents.memoir.extraction_agent import ExtractionAgent +from app.agents.memoir.schemas import StateExtractionOutput from app.agents.stage_constants import ( chat_bucket, normalize_chapter_category, @@ -47,10 +47,10 @@ def test_extraction_agent_normalizes_detected_stage( agent = ExtractionAgent() llm = MagicMock() monkeypatch.setattr( - "app.agents.memoir.extraction_agent.invoke_json_object", - lambda *_a, **_k: json.dumps( - {"detected_stage": "career_early", "slots": {"job": "演员"}}, - ensure_ascii=False, + "app.agents.memoir.extraction_agent.llm_json_call", + lambda *_a, **_k: StateExtractionOutput( + detected_stage="career_early", + slots={"job": "演员"}, ), ) r = agent.extract( @@ -69,10 +69,10 @@ def test_extraction_agent_invalid_detected_falls_back( agent = ExtractionAgent() llm = MagicMock() monkeypatch.setattr( - "app.agents.memoir.extraction_agent.invoke_json_object", - lambda *_a, **_k: json.dumps( - {"detected_stage": "llm_hallucination", "slots": {}}, - ensure_ascii=False, + "app.agents.memoir.extraction_agent.llm_json_call", + lambda *_a, **_k: StateExtractionOutput( + detected_stage="llm_hallucination", + slots={}, ), ) r = agent.extract( @@ -91,15 +91,12 @@ def test_extraction_agent_empty_slots_inherits_current_stage( agent = ExtractionAgent() llm = MagicMock() monkeypatch.setattr( - "app.agents.memoir.extraction_agent.invoke_json_object", - lambda *_a, **_k: json.dumps( - { - "detected_stage": "childhood", - "slots": {}, - "emotion": "neutral", - "is_new_chapter": False, - }, - ensure_ascii=False, + "app.agents.memoir.extraction_agent.llm_json_call", + lambda *_a, **_k: StateExtractionOutput( + detected_stage="childhood", + slots={}, + emotion="neutral", + is_new_chapter=False, ), ) r = agent.extract( diff --git a/api/tests/test_story_route_prompts_and_behavior.py b/api/tests/test_story_route_prompts_and_behavior.py index 8ccadc5..3f163c5 100644 --- a/api/tests/test_story_route_prompts_and_behavior.py +++ b/api/tests/test_story_route_prompts_and_behavior.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json from datetime import datetime, timezone from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -12,7 +11,7 @@ from app.agents.memoir.prompts import ( get_story_route_prompt, story_route_merge_hint_for_category, ) -from app.agents.memoir.story_route_agent import StoryRouteAgent +from app.agents.memoir.story_route_agent import StoryRouteAgent, StoryRouteDecision def test_route_prompt_beliefs_has_strong_container_and_no_uncertain_new(): @@ -52,15 +51,12 @@ def test_merge_hint_career_episodic(): def test_decide_beliefs_mock_llm_append_and_prompt_has_payload(): captured: dict[str, str] = {} - def fake_invoke(_llm, prompt: str, **_kwargs): + def fake_llm_json(_llm, prompt: str, _schema: object, **_kwargs): captured["prompt"] = prompt - return json.dumps( - { - "decision": "append_story", - "target_story_id": "story-a", - "reason": "同一价值观补充", - }, - ensure_ascii=False, + return StoryRouteDecision( + decision="append_story", + target_story_id="story-a", + reason="同一价值观补充", ) cand = SimpleNamespace( @@ -72,8 +68,8 @@ def test_decide_beliefs_mock_llm_append_and_prompt_has_payload(): chapter_links=[], ) with patch( - "app.agents.memoir.story_route_agent.invoke_json_object", - side_effect=fake_invoke, + "app.agents.memoir.story_route_agent.llm_json_call", + side_effect=fake_llm_json, ): agent = StoryRouteAgent() d = agent.decide( @@ -94,15 +90,12 @@ def test_decide_beliefs_mock_llm_append_and_prompt_has_payload(): def test_decide_career_mock_llm_new_story_and_prompt_episodic(): captured: dict[str, str] = {} - def fake_invoke(_llm, prompt: str, **_kwargs): + def fake_llm_json(_llm, prompt: str, _schema: object, **_kwargs): captured["prompt"] = prompt - return json.dumps( - { - "decision": "new_story", - "target_story_id": None, - "reason": "另一段完全不同的任职经历", - }, - ensure_ascii=False, + return StoryRouteDecision( + decision="new_story", + target_story_id=None, + reason="另一段完全不同的任职经历", ) cand = SimpleNamespace( @@ -114,8 +107,8 @@ def test_decide_career_mock_llm_new_story_and_prompt_episodic(): chapter_links=[], ) with patch( - "app.agents.memoir.story_route_agent.invoke_json_object", - side_effect=fake_invoke, + "app.agents.memoir.story_route_agent.llm_json_call", + side_effect=fake_llm_json, ): d = StoryRouteAgent().decide( chapter_category="career_achievement",