# life-echo/api/app/core/llm_call.py
# (Source-viewer note: this file contains Unicode characters, e.g. Chinese text
# in log messages, that some viewers flag as "ambiguous Unicode characters".)
"""
Schema-driven LLM JSON 调用:统一 bind `json_object`、空输出重试、解析校验、结构化日志。
`extract_json_payload` 仅在 **decode 失败时** 作为一次兼容性重试;命中时打
`event=llm_json_compat_strip_hit` 便于后续下线该路径(见计划 Step 13生产观测零命中后再删 compat
"""
from __future__ import annotations
import hashlib
import json
import time
from dataclasses import dataclass
from typing import Any, Callable, Literal, TypeVar
from pydantic import BaseModel, ValidationError
from app.core.agent_logging import agent_verbose_enabled, log_agent_payload
from app.core.json_utils import extract_json_payload
from app.core.langchain_llm import (
bind_json_object_mode,
ensure_json_object_prompt_has_json_keyword,
)
from app.core.logging import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Type variable for the pydantic model class requested by callers; the public
# helpers return an instance of that same class.
T = TypeVar("T", bound=BaseModel)
# Failure stage stored on LLMCallError.kind: the model call itself, JSON
# decoding of its output, or pydantic schema validation.
ErrorKind = Literal["invoke", "decode", "validation"]
class LLMCallError(Exception):
    """Raised when a JSON LLM call fails and no ``fallback_factory`` was given.

    Attributes:
        kind: Failure stage: ``"invoke"``, ``"decode"`` or ``"validation"``.
        raw_content: Model output associated with the failure (possibly
            truncated by the raiser), or ``None`` when there is none.
    """

    def __init__(
        self,
        kind: ErrorKind,
        message: str,
        *,
        raw_content: str | None = None,
    ) -> None:
        super().__init__(message)
        self.kind: ErrorKind = kind
        self.raw_content: str | None = raw_content

    def __reduce__(self) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
        # The inherited BaseException.__reduce__ rebuilds the instance as
        # cls(*self.args) == cls(message). That raises TypeError because
        # ``kind`` is a required positional argument, which breaks pickling
        # (e.g. across a process boundary) and copy.copy / copy.deepcopy.
        # Rebuild with both positional arguments and restore the remaining
        # attributes (raw_content, ...) as instance state.
        message = self.args[0] if self.args else ""
        return (type(self), (self.kind, message), dict(self.__dict__))
@dataclass(frozen=True)
class LLMCallMeta:
    """Immutable summary of one JSON LLM call, logged as a structured record.

    Field order defines the positional constructor; keep it stable.
    """

    # Caller-supplied agent label.
    agent: str
    # Name of the pydantic schema class the output was validated against.
    schema_name: str
    # Token limit used when binding JSON-object mode.
    max_tokens: int
    # Elapsed time of the call, in milliseconds.
    duration_ms: float
    # Number of model invocations reported for the call (may be 0 if the
    # invocation itself raised before reporting).
    attempts: int
    # True when the output decoded and validated successfully.
    parse_ok: bool
    # True when a failed call returned the fallback_factory result.
    used_fallback: bool
    # Failure stage ("invoke" / "decode" / "validation"), or None on success.
    error_kind: str | None
def _prompt_sha12(prompt: str) -> str:
return hashlib.sha256((prompt or "").encode("utf-8")).hexdigest()[:12]
def _invoke_raw_sync(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str,
    retry_empty: bool,
) -> tuple[str, int]:
    """Invoke *llm* in JSON-object mode, re-invoking once on an empty reply.

    The prompt is first passed through
    ``ensure_json_object_prompt_has_json_keyword`` and the model is bound via
    ``bind_json_object_mode(llm, max_tokens=...)``. Reply content is
    whitespace-stripped.

    Returns:
        ``(content, attempts)``: the first non-empty content together with
        the 1-based attempt that produced it, or ``("", max_attempts)`` when
        every attempt came back empty (2 attempts if *retry_empty*, else 1).
        Exceptions raised by ``invoke`` propagate unchanged.
    """
    prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
    bound = bind_json_object_mode(llm, max_tokens=max_tokens)
    label = agent or "json_object"
    digest = _prompt_sha12(prompt_for_api)
    max_attempts = 2 if retry_empty else 1

    attempt = 0
    while attempt < max_attempts:
        reply = bound.invoke(prompt_for_api)
        text = (getattr(reply, "content", None) or "").strip()
        if text:
            if attempt:
                # Only reachable when the first reply was empty.
                logger.info(
                    "json_object 空内容重试成功 agent={} prompt_sha12={}",
                    label,
                    digest,
                )
            return text, attempt + 1
        if retry_empty and attempt == 0:
            logger.warning(
                "json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
                label,
                attempt,
                digest,
            )
        attempt += 1

    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", label, digest)
    return "", max_attempts
async def _invoke_raw_async(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str,
    retry_empty: bool,
) -> tuple[str, int]:
    """Await *llm* in JSON-object mode, re-invoking once on an empty reply.

    The prompt is first passed through
    ``ensure_json_object_prompt_has_json_keyword`` and the model is bound via
    ``bind_json_object_mode(llm, max_tokens=...)``; calls go through
    ``ainvoke``. Reply content is whitespace-stripped.

    Returns:
        ``(content, attempts)``: the first non-empty content together with
        the 1-based attempt that produced it, or ``("", max_attempts)`` when
        every attempt came back empty (2 attempts if *retry_empty*, else 1).
        Exceptions raised by ``ainvoke`` propagate unchanged.
    """
    prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
    bound = bind_json_object_mode(llm, max_tokens=max_tokens)
    label = agent or "json_object"
    digest = _prompt_sha12(prompt_for_api)
    max_attempts = 2 if retry_empty else 1

    for index in range(max_attempts):
        reply = await bound.ainvoke(prompt_for_api)
        text = (getattr(reply, "content", None) or "").strip()
        if not text:
            # Announce the retry only after the first empty reply.
            if retry_empty and index == 0:
                logger.warning(
                    "json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
                    label,
                    index,
                    digest,
                )
            continue
        if index > 0:
            logger.info(
                "json_object 空内容重试成功 agent={} prompt_sha12={}",
                label,
                digest,
            )
        return text, index + 1

    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", label, digest)
    return "", max_attempts
def _parse_and_validate(
    raw: str,
    schema: type[T],
    *,
    agent: str,
) -> T:
    """Decode *raw* as JSON and validate the result into *schema*.

    ``json.loads`` is tried on the stripped text first. Only if that fails is
    ``extract_json_payload`` applied and decoding retried; when the
    extraction actually changed the text, an
    ``event=llm_json_compat_strip_hit`` warning is logged.

    Raises:
        LLMCallError: ``kind="decode"`` for blank input or text that still
            fails to decode; ``kind="validation"`` when
            ``schema.model_validate`` rejects the decoded data. For non-blank
            input ``raw_content`` holds the first 4096 characters of the
            stripped text.
    """
    text = (raw or "").strip()
    if not text:
        raise LLMCallError(
            "decode", "empty llm content for json parse", raw_content=raw
        )

    payload: Any
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        # Compatibility path: strip wrappers (e.g. prose around the JSON).
        candidate = extract_json_payload(text)
        if candidate != text:
            logger.warning(
                "event=llm_json_compat_strip_hit agent={} prompt_kind=decode_retry",
                agent,
            )
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError as exc:
            raise LLMCallError(
                "decode",
                f"json decode failed: {exc}",
                raw_content=text[:4096],
            ) from exc

    try:
        return schema.model_validate(payload)
    except ValidationError as exc:
        raise LLMCallError(
            "validation",
            f"pydantic validation failed: {exc}",
            raw_content=text[:4096],
        ) from exc
def _emit_meta(
    *,
    agent: str,
    schema_name: str,
    max_tokens: int,
    t0: float,
    attempts: int,
    parse_ok: bool,
    used_fallback: bool,
    error_kind: str | None,
) -> None:
    """Log a single ``llm_json_call_done`` record summarising a finished call.

    *t0* is a ``time.perf_counter()`` reading taken when the call started;
    the elapsed time is reported in milliseconds, rounded to 2 decimals.
    All fields are attached to the log record via ``logger.bind``.
    """
    elapsed_ms = (time.perf_counter() - t0) * 1000
    record = LLMCallMeta(
        agent=agent,
        schema_name=schema_name,
        max_tokens=max_tokens,
        duration_ms=elapsed_ms,
        attempts=attempts,
        parse_ok=parse_ok,
        used_fallback=used_fallback,
        error_kind=error_kind,
    )
    context = {
        "event": "llm_json_call",
        "agent": record.agent,
        "schema": record.schema_name,
        "max_tokens": record.max_tokens,
        "duration_ms": round(record.duration_ms, 2),
        "attempts": record.attempts,
        "parse_ok": record.parse_ok,
        "used_fallback": record.used_fallback,
        "error_kind": record.error_kind,
    }
    logger.bind(**context).info("llm_json_call_done")
def llm_json_call(
    llm: Any,
    prompt: str,
    schema: type[T],
    *,
    max_tokens: int,
    agent: str,
    fallback_factory: Callable[[], T] | None = None,
    retry_empty: bool = True,
) -> T:
    """Synchronously invoke *llm*, decode its JSON reply and validate it into *schema*.

    Args:
        llm: Model object accepted by ``bind_json_object_mode``.
        prompt: Prompt text; a JSON keyword is ensured before sending.
        schema: Pydantic model class the decoded JSON must satisfy.
        max_tokens: Token limit used when binding JSON-object mode.
        agent: Label used in log records and verbose payload dumps.
        fallback_factory: Called with no arguments to produce the return
            value when the call fails; if ``None`` the failure is raised.
        retry_empty: Re-invoke once when the first reply is empty.

    Returns:
        The validated *schema* instance, or ``fallback_factory()`` on failure.

    Raises:
        LLMCallError: On failure when *fallback_factory* is ``None``.
            Decode/validation errors are re-raised as-is; any other exception
            from the invoke/parse stage is wrapped with ``kind="invoke"``.

    Exactly one ``llm_json_call_done`` meta record is emitted per call.
    """
    t0 = time.perf_counter()
    schema_name = getattr(schema, "__name__", str(schema))
    # Remain 0 / "" when _invoke_raw_sync raises before returning.
    attempts_used = 0
    raw = ""

    def _report(*, parse_ok: bool, error_kind: str | None) -> None:
        # Emit this call's meta record, then (verbose mode only) dump the
        # prompt as actually sent and the raw response received.
        _emit_meta(
            agent=agent,
            schema_name=schema_name,
            max_tokens=max_tokens,
            t0=t0,
            attempts=attempts_used,
            parse_ok=parse_ok,
            used_fallback=not parse_ok and fallback_factory is not None,
            error_kind=error_kind,
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw)

    # Keep the try narrow: only invoke + parse failures count as call
    # failures. Success-path logging runs outside it, so an error there is
    # not misreported as an "invoke" failure (a second, contradictory meta
    # record) and cannot replace a valid result with the fallback.
    try:
        raw, attempts_used = _invoke_raw_sync(
            llm,
            prompt,
            max_tokens=max_tokens,
            agent=agent,
            retry_empty=retry_empty,
        )
        out = _parse_and_validate(raw, schema, agent=agent)
    except LLMCallError as e:
        _report(parse_ok=False, error_kind=e.kind)
        if fallback_factory is not None:
            return fallback_factory()
        raise
    except Exception as e:
        logger.bind(agent=agent).exception("llm_json_call invoke error: {}", e)
        _report(parse_ok=False, error_kind="invoke")
        if fallback_factory is not None:
            return fallback_factory()
        raise LLMCallError(
            "invoke",
            str(e),
            raw_content=raw[:4096] if raw else None,
        ) from e

    _report(parse_ok=True, error_kind=None)
    return out
async def allm_json_call(
    llm: Any,
    prompt: str,
    schema: type[T],
    *,
    max_tokens: int,
    agent: str,
    fallback_factory: Callable[[], T] | None = None,
    retry_empty: bool = True,
) -> T:
    """Asynchronously invoke *llm*, decode its JSON reply and validate it into *schema*.

    Same contract as the synchronous ``llm_json_call``; the model is called
    through ``ainvoke``.

    Args:
        llm: Model object accepted by ``bind_json_object_mode``.
        prompt: Prompt text; a JSON keyword is ensured before sending.
        schema: Pydantic model class the decoded JSON must satisfy.
        max_tokens: Token limit used when binding JSON-object mode.
        agent: Label used in log records and verbose payload dumps.
        fallback_factory: Called with no arguments to produce the return
            value when the call fails; if ``None`` the failure is raised.
        retry_empty: Re-invoke once when the first reply is empty.

    Returns:
        The validated *schema* instance, or ``fallback_factory()`` on failure.

    Raises:
        LLMCallError: On failure when *fallback_factory* is ``None``.
            Decode/validation errors are re-raised as-is; any other exception
            from the invoke/parse stage is wrapped with ``kind="invoke"``.

    Exactly one ``llm_json_call_done`` meta record is emitted per call.
    """
    t0 = time.perf_counter()
    schema_name = getattr(schema, "__name__", str(schema))
    # Remain 0 / "" when _invoke_raw_async raises before returning.
    attempts_used = 0
    raw = ""

    def _report(*, parse_ok: bool, error_kind: str | None) -> None:
        # Emit this call's meta record, then (verbose mode only) dump the
        # prompt as actually sent and the raw response received.
        _emit_meta(
            agent=agent,
            schema_name=schema_name,
            max_tokens=max_tokens,
            t0=t0,
            attempts=attempts_used,
            parse_ok=parse_ok,
            used_fallback=not parse_ok and fallback_factory is not None,
            error_kind=error_kind,
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw)

    # Keep the try narrow: only invoke + parse failures count as call
    # failures. Success-path logging runs outside it, so an error there is
    # not misreported as an "invoke" failure (a second, contradictory meta
    # record) and cannot replace a valid result with the fallback.
    try:
        raw, attempts_used = await _invoke_raw_async(
            llm,
            prompt,
            max_tokens=max_tokens,
            agent=agent,
            retry_empty=retry_empty,
        )
        out = _parse_and_validate(raw, schema, agent=agent)
    except LLMCallError as e:
        _report(parse_ok=False, error_kind=e.kind)
        if fallback_factory is not None:
            return fallback_factory()
        raise
    except Exception as e:
        logger.bind(agent=agent).exception("allm_json_call invoke error: {}", e)
        _report(parse_ok=False, error_kind="invoke")
        if fallback_factory is not None:
            return fallback_factory()
        raise LLMCallError(
            "invoke",
            str(e),
            raw_content=raw[:4096] if raw else None,
        ) from e

    _report(parse_ok=True, error_kind=None)
    return out
# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = [
    "LLMCallError",
    "LLMCallMeta",
    "allm_json_call",
    "llm_json_call",
]