""" 与 `get_llm_provider().langchain_llm` 配合使用的 LangChain Runnable 约定。 langchain-openai 要求用顶层 `response_format` 绑定 JSON 模式,禁止对 `.bind()` 传入 `model_kwargs={"response_format": ...}`(会错误传入底层 `completions.create`)。 """ from __future__ import annotations import hashlib import time from typing import Any from app.core.agent_logging import ( agent_summary_enabled, agent_verbose_enabled, log_agent_payload, ) from app.core.llm_telemetry import infer_provider_model, langchain_invoke_span from app.core.logging import get_logger logger = get_logger(__name__) # OpenAI / DeepSeek:使用 response_format=json_object 时,prompt 须含子串「json」 # (见 DeepSeek JSON Output 指南)。 _JSON_OBJECT_PROMPT_SUFFIX = ( "\n\n【JSON】只输出一个合法 JSON 对象,不要其它说明文字或 markdown。" ) def ensure_json_object_prompt_has_json_keyword(prompt: str) -> str: """ 若整段 prompt 中未出现 ``json``(大小写不敏感),追加一行合规提示。 供所有 ``response_format: json_object`` 调用在发请求前统一处理。 """ p = prompt or "" if "json" in p.casefold(): return p return f"{p.rstrip()}{_JSON_OBJECT_PROMPT_SUFFIX}" def bind_json_object_mode(llm: Any, *, max_tokens: int) -> Any: """返回绑定 `response_format=json_object` 与 `max_tokens` 的 Runnable(通常为 ChatOpenAI)。""" return llm.bind( response_format={"type": "json_object"}, max_tokens=max_tokens, ) def _prompt_sha12(prompt: str) -> str: return hashlib.sha256((prompt or "").encode("utf-8")).hexdigest()[:12] def invoke_json_object( llm: Any, prompt: str, *, max_tokens: int, agent: str | None = None, retry_empty: bool = True, ) -> str: """ 同步调用 JSON object 模式;空 content 时可选重试一次(缓解 DeepSeek 偶发空输出)。 仅依赖 bind_json_object_mode,不引用 features。 """ prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt) bound = bind_json_object_mode(llm, max_tokens=max_tokens) tag = agent or "json_object" sha = _prompt_sha12(prompt_for_api) attempts = 2 if retry_empty else 1 t0 = time.perf_counter() provider, model = infer_provider_model(llm) last_content = "" with langchain_invoke_span( agent=tag, provider=provider, model=model, call_type="json", prompt_sha12=sha, max_tokens=max_tokens, ) as tel: for attempt in range(attempts): response = bound.invoke(prompt_for_api) tel["response"] = response content = (getattr(response, "content", None) or "").strip() last_content = content if content: if attempt > 0: logger.info( "json_object 空内容重试成功 agent={} prompt_sha12={}", tag, sha, ) tel["outcome"] = "ok" _log_json_object_done( tag, sha, prompt_for_api, content, attempt + 1, t0, success=True ) return content if attempt == 0 and retry_empty: logger.warning( "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}", tag, attempt, sha, ) tel["outcome"] = "error" logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha) _log_json_object_done( tag, sha, prompt_for_api, last_content, attempts, t0, success=False ) return "" async def ainvoke_json_object( llm: Any, prompt: str, *, max_tokens: int, agent: str | None = None, retry_empty: bool = True, ) -> str: """异步版 `invoke_json_object`。""" prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt) bound = bind_json_object_mode(llm, max_tokens=max_tokens) tag = agent or "json_object" sha = _prompt_sha12(prompt_for_api) attempts = 2 if retry_empty else 1 t0 = time.perf_counter() provider, model = infer_provider_model(llm) last_content = "" with langchain_invoke_span( agent=tag, provider=provider, model=model, call_type="json", prompt_sha12=sha, max_tokens=max_tokens, ) as tel: for attempt in range(attempts): response = await bound.ainvoke(prompt_for_api) tel["response"] = response content = (getattr(response, "content", None) or "").strip() last_content = content if content: if attempt > 0: logger.info( "json_object 空内容重试成功 agent={} prompt_sha12={}", tag, sha, ) tel["outcome"] = "ok" _log_json_object_done( tag, sha, prompt_for_api, content, attempt + 1, t0, success=True ) return content if attempt == 0 and retry_empty: logger.warning( "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}", tag, attempt, sha, ) tel["outcome"] = "error" logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha) _log_json_object_done( tag, sha, prompt_for_api, last_content, attempts, t0, success=False ) return "" def _log_json_object_done( tag: str, sha: str, prompt: str, content: str, attempts_used: int, t0: float, *, success: bool, ) -> None: ms = (time.perf_counter() - t0) * 1000 if agent_summary_enabled(): prompt_chars = len(prompt or "") logger.info( "llm_json_object agent={} prompt_sha12={} duration_ms={:.2f} " "prompt_char_count={} response_len={} attempts={} success={}", tag, sha, ms, prompt_chars, len(content or ""), attempts_used, success, ) if agent_verbose_enabled(): log_agent_payload(logger, f"{tag}.prompt", prompt) log_agent_payload(logger, f"{tag}.response", content)