feat(api): 统一 LLM JSON 调用层 llm_json_call,按域 Schema 迁移 chat/memoir agents

This commit is contained in:
Kevin
2026-04-03 13:34:27 +08:00
parent 41518bda11
commit 43d1689e9c
28 changed files with 1006 additions and 352 deletions

402
api/app/core/llm_call.py Normal file
View File

@@ -0,0 +1,402 @@
"""
Schema-driven LLM JSON 调用:统一 bind `json_object`、空输出重试、解析校验、结构化日志。
`extract_json_payload` 仅在 **decode 失败时** 作为一次兼容性重试;命中时打
`event=llm_json_compat_strip_hit` 便于后续下线该路径(见计划 Step 13生产观测零命中后再删 compat
"""
from __future__ import annotations
import hashlib
import json
import time
from dataclasses import dataclass
from typing import Any, Callable, Literal, TypeVar
from pydantic import BaseModel, ValidationError
from app.core.agent_logging import agent_verbose_enabled, log_agent_payload
from app.core.json_utils import extract_json_payload
from app.core.langchain_llm import (
bind_json_object_mode,
ensure_json_object_prompt_has_json_keyword,
)
from app.core.logging import get_logger
logger = get_logger(__name__)
T = TypeVar("T", bound=BaseModel)
ErrorKind = Literal["invoke", "decode", "validation"]
class LLMCallError(Exception):
"""未提供 fallback_factory 且调用链失败时抛出。"""
def __init__(
self,
kind: ErrorKind,
message: str,
*,
raw_content: str | None = None,
) -> None:
super().__init__(message)
self.kind: ErrorKind = kind
self.raw_content: str | None = raw_content
@dataclass(frozen=True)
class LLMCallMeta:
agent: str
schema_name: str
max_tokens: int
duration_ms: float
attempts: int
parse_ok: bool
used_fallback: bool
error_kind: str | None
def _prompt_sha12(prompt: str) -> str:
return hashlib.sha256((prompt or "").encode("utf-8")).hexdigest()[:12]
def _invoke_raw_sync(
llm: Any,
prompt: str,
*,
max_tokens: int,
agent: str,
retry_empty: bool,
) -> tuple[str, int]:
prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
bound = bind_json_object_mode(llm, max_tokens=max_tokens)
tag = agent or "json_object"
sha = _prompt_sha12(prompt_for_api)
attempts = 2 if retry_empty else 1
for attempt in range(attempts):
response = bound.invoke(prompt_for_api)
content = (getattr(response, "content", None) or "").strip()
if content:
if attempt > 0:
logger.info(
"json_object 空内容重试成功 agent={} prompt_sha12={}",
tag,
sha,
)
return content, attempt + 1
if attempt == 0 and retry_empty:
logger.warning(
"json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
tag,
attempt,
sha,
)
logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha)
return "", attempts
async def _invoke_raw_async(
llm: Any,
prompt: str,
*,
max_tokens: int,
agent: str,
retry_empty: bool,
) -> tuple[str, int]:
prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
bound = bind_json_object_mode(llm, max_tokens=max_tokens)
tag = agent or "json_object"
sha = _prompt_sha12(prompt_for_api)
attempts = 2 if retry_empty else 1
for attempt in range(attempts):
response = await bound.ainvoke(prompt_for_api)
content = (getattr(response, "content", None) or "").strip()
if content:
if attempt > 0:
logger.info(
"json_object 空内容重试成功 agent={} prompt_sha12={}",
tag,
sha,
)
return content, attempt + 1
if attempt == 0 and retry_empty:
logger.warning(
"json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
tag,
attempt,
sha,
)
logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha)
return "", attempts
def _parse_and_validate(
raw: str,
schema: type[T],
*,
agent: str,
) -> T:
s = (raw or "").strip()
if not s:
raise LLMCallError(
"decode", "empty llm content for json parse", raw_content=raw
)
data: Any
try:
data = json.loads(s)
except json.JSONDecodeError:
stripped = extract_json_payload(s)
if stripped != s:
logger.warning(
"event=llm_json_compat_strip_hit agent={} prompt_kind=decode_retry",
agent,
)
try:
data = json.loads(stripped)
except json.JSONDecodeError as e:
raise LLMCallError(
"decode",
f"json decode failed: {e}",
raw_content=s[:4096],
) from e
try:
return schema.model_validate(data)
except ValidationError as e:
raise LLMCallError(
"validation",
f"pydantic validation failed: {e}",
raw_content=s[:4096],
) from e
def _emit_meta(
*,
agent: str,
schema_name: str,
max_tokens: int,
t0: float,
attempts: int,
parse_ok: bool,
used_fallback: bool,
error_kind: str | None,
) -> None:
meta = LLMCallMeta(
agent=agent,
schema_name=schema_name,
max_tokens=max_tokens,
duration_ms=(time.perf_counter() - t0) * 1000,
attempts=attempts,
parse_ok=parse_ok,
used_fallback=used_fallback,
error_kind=error_kind,
)
logger.bind(
event="llm_json_call",
agent=meta.agent,
schema=meta.schema_name,
max_tokens=meta.max_tokens,
duration_ms=round(meta.duration_ms, 2),
attempts=meta.attempts,
parse_ok=meta.parse_ok,
used_fallback=meta.used_fallback,
error_kind=meta.error_kind,
).info("llm_json_call_done")
def llm_json_call(
llm: Any,
prompt: str,
schema: type[T],
*,
max_tokens: int,
agent: str,
fallback_factory: Callable[[], T] | None = None,
retry_empty: bool = True,
) -> T:
"""同步invoke → 解析 JSON → `schema.model_validate`;失败时 `fallback_factory` 或 `LLMCallError`。"""
t0 = time.perf_counter()
schema_name = getattr(schema, "__name__", str(schema))
attempts_used = 0
raw = ""
try:
raw, attempts_used = _invoke_raw_sync(
llm,
prompt,
max_tokens=max_tokens,
agent=agent,
retry_empty=retry_empty,
)
out = _parse_and_validate(raw, schema, agent=agent)
_emit_meta(
agent=agent,
schema_name=schema_name,
max_tokens=max_tokens,
t0=t0,
attempts=attempts_used,
parse_ok=True,
used_fallback=False,
error_kind=None,
)
if agent_verbose_enabled():
log_agent_payload(
logger,
f"{agent}.prompt",
ensure_json_object_prompt_has_json_keyword(prompt),
)
log_agent_payload(logger, f"{agent}.response", raw)
return out
except LLMCallError as e:
used_fb = fallback_factory is not None
_emit_meta(
agent=agent,
schema_name=schema_name,
max_tokens=max_tokens,
t0=t0,
attempts=attempts_used,
parse_ok=False,
used_fallback=used_fb,
error_kind=e.kind,
)
if agent_verbose_enabled():
log_agent_payload(
logger,
f"{agent}.prompt",
ensure_json_object_prompt_has_json_keyword(prompt),
)
log_agent_payload(logger, f"{agent}.response", raw)
if fallback_factory is not None:
return fallback_factory()
raise
except Exception as e:
logger.bind(agent=agent).exception("llm_json_call invoke error: {}", e)
used_fb = fallback_factory is not None
_emit_meta(
agent=agent,
schema_name=schema_name,
max_tokens=max_tokens,
t0=t0,
attempts=attempts_used,
parse_ok=False,
used_fallback=used_fb,
error_kind="invoke",
)
if agent_verbose_enabled():
log_agent_payload(
logger,
f"{agent}.prompt",
ensure_json_object_prompt_has_json_keyword(prompt),
)
log_agent_payload(logger, f"{agent}.response", raw)
if fallback_factory is not None:
return fallback_factory()
raise LLMCallError(
"invoke",
str(e),
raw_content=raw[:4096] if raw else None,
) from e
async def allm_json_call(
llm: Any,
prompt: str,
schema: type[T],
*,
max_tokens: int,
agent: str,
fallback_factory: Callable[[], T] | None = None,
retry_empty: bool = True,
) -> T:
"""异步版,语义与 `llm_json_call` 一致。"""
t0 = time.perf_counter()
schema_name = getattr(schema, "__name__", str(schema))
attempts_used = 0
raw = ""
try:
raw, attempts_used = await _invoke_raw_async(
llm,
prompt,
max_tokens=max_tokens,
agent=agent,
retry_empty=retry_empty,
)
out = _parse_and_validate(raw, schema, agent=agent)
_emit_meta(
agent=agent,
schema_name=schema_name,
max_tokens=max_tokens,
t0=t0,
attempts=attempts_used,
parse_ok=True,
used_fallback=False,
error_kind=None,
)
if agent_verbose_enabled():
log_agent_payload(
logger,
f"{agent}.prompt",
ensure_json_object_prompt_has_json_keyword(prompt),
)
log_agent_payload(logger, f"{agent}.response", raw)
return out
except LLMCallError as e:
used_fb = fallback_factory is not None
_emit_meta(
agent=agent,
schema_name=schema_name,
max_tokens=max_tokens,
t0=t0,
attempts=attempts_used,
parse_ok=False,
used_fallback=used_fb,
error_kind=e.kind,
)
if agent_verbose_enabled():
log_agent_payload(
logger,
f"{agent}.prompt",
ensure_json_object_prompt_has_json_keyword(prompt),
)
log_agent_payload(logger, f"{agent}.response", raw)
if fallback_factory is not None:
return fallback_factory()
raise
except Exception as e:
logger.bind(agent=agent).exception("allm_json_call invoke error: {}", e)
used_fb = fallback_factory is not None
_emit_meta(
agent=agent,
schema_name=schema_name,
max_tokens=max_tokens,
t0=t0,
attempts=attempts_used,
parse_ok=False,
used_fallback=used_fb,
error_kind="invoke",
)
if agent_verbose_enabled():
log_agent_payload(
logger,
f"{agent}.prompt",
ensure_json_object_prompt_has_json_keyword(prompt),
)
log_agent_payload(logger, f"{agent}.response", raw)
if fallback_factory is not None:
return fallback_factory()
raise LLMCallError(
"invoke",
str(e),
raw_content=raw[:4096] if raw else None,
) from e
__all__ = [
"LLMCallError",
"LLMCallMeta",
"allm_json_call",
"llm_json_call",
]