403 lines
12 KiB
Python
403 lines
12 KiB
Python
|
|
"""
|
|||
|
|
Schema-driven LLM JSON 调用:统一 bind `json_object`、空输出重试、解析校验、结构化日志。
|
|||
|
|
|
|||
|
|
`extract_json_payload` 仅在 **decode 失败时** 作为一次兼容性重试;命中时打
|
|||
|
|
`event=llm_json_compat_strip_hit` 便于后续下线该路径(见计划 Step 13;生产观测零命中后再删 compat)。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import hashlib
|
|||
|
|
import json
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Any, Callable, Literal, TypeVar
|
|||
|
|
|
|||
|
|
from pydantic import BaseModel, ValidationError
|
|||
|
|
|
|||
|
|
from app.core.agent_logging import agent_verbose_enabled, log_agent_payload
|
|||
|
|
from app.core.json_utils import extract_json_payload
|
|||
|
|
from app.core.langchain_llm import (
|
|||
|
|
bind_json_object_mode,
|
|||
|
|
ensure_json_object_prompt_has_json_keyword,
|
|||
|
|
)
|
|||
|
|
from app.core.logging import get_logger
|
|||
|
|
|
|||
|
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Generic parameter for the pydantic schema class a call validates into.
T = TypeVar("T", bound=BaseModel)

# Stage at which a call failed; carried by ``LLMCallError.kind``.
ErrorKind = Literal["invoke", "decode", "validation"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMCallError(Exception):
    """Raised when the call chain fails and no ``fallback_factory`` was given.

    Attributes:
        kind: Stage that failed: ``"invoke"``, ``"decode"`` or ``"validation"``.
        raw_content: Raw model output associated with the failure, or ``None``.
    """

    def __init__(
        self,
        kind: ErrorKind,
        message: str,
        *,
        raw_content: str | None = None,
    ) -> None:
        """Store the failure stage and optional raw output.

        Args:
            kind: Stage that failed.
            message: Human-readable description; becomes ``str(exc)``.
            raw_content: Raw model output to keep for diagnostics, if any.
        """
        super().__init__(message)
        self.kind: ErrorKind = kind
        self.raw_content: str | None = raw_content

    def __reduce__(self) -> tuple[Any, ...]:
        """Make the exception work with ``pickle`` and ``copy``.

        The default exception reduction re-calls the class with ``self.args``,
        which holds only the message. That does not match this constructor's
        two required positionals, so unpickling or copying raised ``TypeError``.
        Rebuild with ``(kind, message)`` and restore the remaining attributes
        from the instance dict.
        """
        message = self.args[0] if self.args else ""
        return (type(self), (self.kind, message), dict(vars(self)))
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True)
class LLMCallMeta:
    """Immutable summary of a single LLM JSON call, used for structured logging."""

    # Agent label supplied by the caller.
    agent: str
    # Name of the schema class the response is validated against.
    schema_name: str
    # Token limit passed when binding the model.
    max_tokens: int
    # Elapsed time of the call in milliseconds.
    duration_ms: float
    # Number of model invocations reported for the call.
    attempts: int
    # True when the response was parsed and validated successfully.
    parse_ok: bool
    # True when a fallback factory was available on a failed call.
    used_fallback: bool
    # Failure stage ("invoke", "decode", "validation"), or None on success.
    error_kind: str | None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _prompt_sha12(prompt: str) -> str:
    """Return the first 12 hex characters of the SHA-256 digest of *prompt*.

    A falsy prompt (``None`` or ``""``) is hashed as the empty string.
    """
    encoded = (prompt or "").encode("utf-8")
    digest = hashlib.sha256(encoded).hexdigest()
    return digest[:12]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _invoke_raw_sync(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str,
    retry_empty: bool,
) -> tuple[str, int]:
    """Invoke the model in JSON-object mode, optionally retrying once on empty output.

    Args:
        llm: Model object handed to ``bind_json_object_mode``.
        prompt: Prompt text; passed through
            ``ensure_json_object_prompt_has_json_keyword`` before sending.
        max_tokens: Token limit forwarded when binding the model.
        agent: Label used in log lines; ``"json_object"`` is used when falsy.
        retry_empty: When true, an empty first response triggers one more call.

    Returns:
        ``(content, calls_made)`` with the stripped response content, or
        ``("", call_budget)`` when every call returned empty content.
    """
    api_prompt = ensure_json_object_prompt_has_json_keyword(prompt)
    json_llm = bind_json_object_mode(llm, max_tokens=max_tokens)
    label = agent or "json_object"
    digest = _prompt_sha12(api_prompt)
    call_budget = 2 if retry_empty else 1

    for call_no in range(1, call_budget + 1):
        reply = json_llm.invoke(api_prompt)
        text = (getattr(reply, "content", None) or "").strip()

        if not text:
            # Only the first empty reply announces a retry; the final one is
            # reported after the loop.
            if retry_empty and call_no == 1:
                logger.warning(
                    "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}",
                    label,
                    call_no - 1,
                    digest,
                )
            continue

        if call_no > 1:
            logger.info(
                "json_object 空内容重试成功 agent={} prompt_sha12={}",
                label,
                digest,
            )
        return text, call_no

    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", label, digest)
    return "", call_budget
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _invoke_raw_async(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str,
    retry_empty: bool,
) -> tuple[str, int]:
    """Async counterpart of the raw JSON-object invocation, using ``ainvoke``.

    Args:
        llm: Model object handed to ``bind_json_object_mode``.
        prompt: Prompt text; passed through
            ``ensure_json_object_prompt_has_json_keyword`` before sending.
        max_tokens: Token limit forwarded when binding the model.
        agent: Label used in log lines; ``"json_object"`` is used when falsy.
        retry_empty: When true, an empty first response triggers one more call.

    Returns:
        ``(content, calls_made)`` with the stripped response content, or
        ``("", call_budget)`` when every call returned empty content.
    """
    api_prompt = ensure_json_object_prompt_has_json_keyword(prompt)
    json_llm = bind_json_object_mode(llm, max_tokens=max_tokens)
    label = agent or "json_object"
    digest = _prompt_sha12(api_prompt)
    call_budget = 2 if retry_empty else 1

    for call_no in range(1, call_budget + 1):
        reply = await json_llm.ainvoke(api_prompt)
        text = (getattr(reply, "content", None) or "").strip()

        if not text:
            # Only the first empty reply announces a retry; the final one is
            # reported after the loop.
            if retry_empty and call_no == 1:
                logger.warning(
                    "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}",
                    label,
                    call_no - 1,
                    digest,
                )
            continue

        if call_no > 1:
            logger.info(
                "json_object 空内容重试成功 agent={} prompt_sha12={}",
                label,
                digest,
            )
        return text, call_no

    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", label, digest)
    return "", call_budget
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_and_validate(
    raw: str,
    schema: type[T],
    *,
    agent: str,
) -> T:
    """Decode *raw* as JSON and validate it against *schema*.

    If the first ``json.loads`` fails, the text is passed through
    ``extract_json_payload`` and decoded once more; a warning is logged when
    that helper actually changed the text.

    Args:
        raw: Model output to parse.
        schema: Pydantic model class used for validation.
        agent: Label included in the compatibility-hit log line.

    Returns:
        The validated schema instance.

    Raises:
        LLMCallError: ``kind="decode"`` when the content is empty or still not
            decodable after the compatibility retry; ``kind="validation"`` when
            ``schema.model_validate`` raises ``ValidationError``.
    """
    text = (raw or "").strip()
    if not text:
        raise LLMCallError(
            "decode", "empty llm content for json parse", raw_content=raw
        )

    payload: Any
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        candidate = extract_json_payload(text)
        if candidate != text:
            # Marks a hit on the compat path so its usage can be observed.
            logger.warning(
                "event=llm_json_compat_strip_hit agent={} prompt_kind=decode_retry",
                agent,
            )
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError as exc:
            raise LLMCallError(
                "decode",
                f"json decode failed: {exc}",
                raw_content=text[:4096],
            ) from exc

    try:
        validated = schema.model_validate(payload)
    except ValidationError as exc:
        raise LLMCallError(
            "validation",
            f"pydantic validation failed: {exc}",
            raw_content=text[:4096],
        ) from exc
    return validated
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _emit_meta(
    *,
    agent: str,
    schema_name: str,
    max_tokens: int,
    t0: float,
    attempts: int,
    parse_ok: bool,
    used_fallback: bool,
    error_kind: str | None,
) -> None:
    """Log one ``llm_json_call`` event describing a finished call.

    Args:
        agent: Agent label for the call.
        schema_name: Name of the schema the response was validated against.
        max_tokens: Token limit used for the call.
        t0: ``time.perf_counter()`` reading taken when the call started.
        attempts: Number of model invocations reported by the caller.
        parse_ok: Whether parsing and validation succeeded.
        used_fallback: Whether a fallback factory was available on failure.
        error_kind: Failure stage, or ``None`` on success.
    """
    elapsed_ms = (time.perf_counter() - t0) * 1000
    record = LLMCallMeta(
        agent=agent,
        schema_name=schema_name,
        max_tokens=max_tokens,
        duration_ms=elapsed_ms,
        attempts=attempts,
        parse_ok=parse_ok,
        used_fallback=used_fallback,
        error_kind=error_kind,
    )
    log_fields = {
        "event": "llm_json_call",
        "agent": record.agent,
        "schema": record.schema_name,
        "max_tokens": record.max_tokens,
        # Duration is rounded only for the log; the record keeps full precision.
        "duration_ms": round(record.duration_ms, 2),
        "attempts": record.attempts,
        "parse_ok": record.parse_ok,
        "used_fallback": record.used_fallback,
        "error_kind": record.error_kind,
    }
    logger.bind(**log_fields).info("llm_json_call_done")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def llm_json_call(
    llm: Any,
    prompt: str,
    schema: type[T],
    *,
    max_tokens: int,
    agent: str,
    fallback_factory: Callable[[], T] | None = None,
    retry_empty: bool = True,
) -> T:
    """Synchronously invoke the model, parse its JSON and validate against *schema*.

    Args:
        llm: Model object to invoke.
        prompt: Prompt text.
        schema: Pydantic model class the response must validate against.
        max_tokens: Token limit for the call.
        agent: Label used in logs and metadata.
        fallback_factory: Optional zero-argument callable whose result is
            returned when the call fails at any stage.
        retry_empty: Whether an empty response triggers one retry.

    Returns:
        The validated schema instance, or ``fallback_factory()`` on failure.

    Raises:
        LLMCallError: When the call fails and no ``fallback_factory`` is given.
            Errors that are not already ``LLMCallError`` are wrapped with
            ``kind="invoke"``.
    """
    started = time.perf_counter()
    schema_label = getattr(schema, "__name__", str(schema))
    calls_made = 0
    raw_text = ""

    try:
        raw_text, calls_made = _invoke_raw_sync(
            llm,
            prompt,
            max_tokens=max_tokens,
            agent=agent,
            retry_empty=retry_empty,
        )
        parsed = _parse_and_validate(raw_text, schema, agent=agent)
        _emit_meta(
            agent=agent,
            schema_name=schema_label,
            max_tokens=max_tokens,
            t0=started,
            attempts=calls_made,
            parse_ok=True,
            used_fallback=False,
            error_kind=None,
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw_text)
        return parsed
    except Exception as exc:
        # LLMCallError already carries its stage; anything else is logged with
        # a traceback and reported as an "invoke" failure.
        typed = isinstance(exc, LLMCallError)
        if not typed:
            logger.bind(agent=agent).exception("llm_json_call invoke error: {}", exc)
        has_fallback = fallback_factory is not None
        _emit_meta(
            agent=agent,
            schema_name=schema_label,
            max_tokens=max_tokens,
            t0=started,
            attempts=calls_made,
            parse_ok=False,
            used_fallback=has_fallback,
            error_kind=exc.kind if typed else "invoke",
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw_text)
        if has_fallback:
            return fallback_factory()
        if typed:
            raise
        raise LLMCallError(
            "invoke",
            str(exc),
            raw_content=raw_text[:4096] if raw_text else None,
        ) from exc
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def allm_json_call(
    llm: Any,
    prompt: str,
    schema: type[T],
    *,
    max_tokens: int,
    agent: str,
    fallback_factory: Callable[[], T] | None = None,
    retry_empty: bool = True,
) -> T:
    """Asynchronously invoke the model, parse its JSON and validate against *schema*.

    Args:
        llm: Model object to invoke.
        prompt: Prompt text.
        schema: Pydantic model class the response must validate against.
        max_tokens: Token limit for the call.
        agent: Label used in logs and metadata.
        fallback_factory: Optional zero-argument callable whose result is
            returned when the call fails at any stage.
        retry_empty: Whether an empty response triggers one retry.

    Returns:
        The validated schema instance, or ``fallback_factory()`` on failure.

    Raises:
        LLMCallError: When the call fails and no ``fallback_factory`` is given.
            Errors that are not already ``LLMCallError`` are wrapped with
            ``kind="invoke"``.
    """
    started = time.perf_counter()
    schema_label = getattr(schema, "__name__", str(schema))
    calls_made = 0
    raw_text = ""

    try:
        raw_text, calls_made = await _invoke_raw_async(
            llm,
            prompt,
            max_tokens=max_tokens,
            agent=agent,
            retry_empty=retry_empty,
        )
        parsed = _parse_and_validate(raw_text, schema, agent=agent)
        _emit_meta(
            agent=agent,
            schema_name=schema_label,
            max_tokens=max_tokens,
            t0=started,
            attempts=calls_made,
            parse_ok=True,
            used_fallback=False,
            error_kind=None,
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw_text)
        return parsed
    except Exception as exc:
        # LLMCallError already carries its stage; anything else is logged with
        # a traceback and reported as an "invoke" failure.
        typed = isinstance(exc, LLMCallError)
        if not typed:
            logger.bind(agent=agent).exception("allm_json_call invoke error: {}", exc)
        has_fallback = fallback_factory is not None
        _emit_meta(
            agent=agent,
            schema_name=schema_label,
            max_tokens=max_tokens,
            t0=started,
            attempts=calls_made,
            parse_ok=False,
            used_fallback=has_fallback,
            error_kind=exc.kind if typed else "invoke",
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw_text)
        if has_fallback:
            return fallback_factory()
        if typed:
            raise
        raise LLMCallError(
            "invoke",
            str(exc),
            raw_content=raw_text[:4096] if raw_text else None,
        ) from exc
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Public API of this module.
__all__ = [
    "LLMCallError",
    "LLMCallMeta",
    "allm_json_call",
    "llm_json_call",
]
|