- Judge baseline excerpt and library chapter separately; build_memoir_compare_summary for gate, nine-dim and leaf deltas. - Memoir SSE chapter payload: baseline_judge, compare_summary, baseline_judge_error. - MemoirJudgeOutput: loose score coercion and post-validate clamp; memoir judge prompt caps from settings. - app-eval-web: two-column MemoirScoreCard layout, MemoirCompareSummary, chapter blocks and CSS. - Add memoir_compare_summary, log_events, celery_log_context, memoir_pipeline_progress; tests and migration 0014. - Misc: memory/evidence and enrichment paths, task/orchestrator updates, internal-eval docs, env examples.
461 lines
14 KiB
Python
461 lines
14 KiB
Python
"""
|
||
Schema-driven LLM JSON 调用:统一 bind `json_object`、空输出重试、解析校验、结构化日志。
|
||
|
||
`extract_json_payload` 仅在 **decode 失败时** 作为一次兼容性重试;命中时打
|
||
`event=llm_json_compat_strip_hit` 便于后续下线该路径(见计划 Step 13;生产观测零命中后再删 compat)。
|
||
"""
|
||
|
||
from __future__ import annotations

import hashlib
import json
import time
from dataclasses import dataclass
from typing import Any, Callable, Literal, TypeVar

from pydantic import BaseModel, ValidationError

try:
    from openai import (
        ContentFilterFinishReasonError as _OpenAIContentFilterFinishReasonError,
    )
except ImportError:  # Compatibility: older SDK versions do not provide this class.
    _OpenAIContentFilterFinishReasonError = None

from app.core.agent_logging import agent_verbose_enabled, log_agent_payload
from app.core.json_utils import extract_json_payload
from app.core.langchain_llm import (
    bind_json_object_mode,
    ensure_json_object_prompt_has_json_keyword,
)
from app.core.logging import get_logger

# Module-level logger shared by every helper in this file.
logger = get_logger(__name__)
|
||
|
||
# Generic result type: the pydantic model class passed as ``schema``.
T = TypeVar("T", bound=BaseModel)

# Stage at which a call failed; stored on ``LLMCallError.kind``.
ErrorKind = Literal["invoke", "decode", "validation"]
|
||
|
||
|
||
class LLMCallError(Exception):
    """Raised when the call chain fails and no ``fallback_factory`` was given.

    Attributes:
        kind: Stage that failed: ``"invoke"``, ``"decode"`` or ``"validation"``.
        raw_content: Raw model output associated with the failure, if any.
    """

    def __init__(
        self,
        kind: ErrorKind,
        message: str,
        *,
        raw_content: str | None = None,
    ) -> None:
        super().__init__(message)
        self.kind: ErrorKind = kind
        self.raw_content: str | None = raw_content

    def __reduce__(self):
        """Support pickling/copying despite the non-default ``__init__`` signature.

        ``BaseException`` rebuilds instances as ``cls(*self.args)``, but ``args``
        only holds the message, so the required ``kind`` argument would be
        missing and unpickling would raise ``TypeError``. Rebuild with both
        positional arguments and restore the remaining attributes from the
        instance dictionary.
        """
        message = self.args[0] if self.args else ""
        return (type(self), (self.kind, message), dict(vars(self)))
|
||
|
||
|
||
@dataclass(frozen=True)
class LLMCallMeta:
    """Immutable summary of one schema-driven LLM JSON call, built for logging."""

    agent: str  # caller-supplied agent label
    schema_name: str  # name of the target schema class
    max_tokens: int  # token limit passed when binding the model
    duration_ms: float  # elapsed time of the call, in milliseconds
    attempts: int  # invocation count reported by the raw-invoke helper
    parse_ok: bool  # True when the output was parsed and validated
    used_fallback: bool  # True when a fallback factory was supplied on failure
    error_kind: str | None  # failing stage, or None on success
|
||
|
||
|
||
def _prompt_sha12(prompt: str) -> str:
|
||
return hashlib.sha256((prompt or "").encode("utf-8")).hexdigest()[:12]
|
||
|
||
|
||
def _iter_exception_chain(exc: BaseException):
|
||
"""包含自身与 ``__cause__`` / ``__context__`` 链,去重防环。"""
|
||
seen: set[int] = set()
|
||
cur: BaseException | None = exc
|
||
while cur is not None and id(cur) not in seen:
|
||
yield cur
|
||
seen.add(id(cur))
|
||
cur = cur.__cause__ or cur.__context__
|
||
|
||
|
||
def _is_content_filter_refusal(exc: BaseException) -> bool:
    """Return True when ``exc`` or a chained exception is a content-filter block.

    A content-moderation block (OpenAI / Azure and similar) produces no model
    JSON to parse; it is an expected failure and should not be logged as an
    ERROR with a stack trace.

    An exception matches when it is an instance of the OpenAI SDK's
    ``ContentFilterFinishReasonError`` (if that class could be imported) or
    when its lower-cased message contains ``"content filter"``.
    """
    filter_error_cls = _OpenAIContentFilterFinishReasonError
    for err in _iter_exception_chain(exc):
        if filter_error_cls is not None and isinstance(err, filter_error_cls):
            return True
        # The phrase alone is the criterion: any message containing it is
        # treated as a refusal, regardless of other wording.
        if "content filter" in str(err).lower():
            return True
    return False
|
||
|
||
|
||
_LLM_MSG_CONTENT_FILTER = (
|
||
"模型输出被服务商内容安全策略拦截(content filter),通常与提示或上下文中触发了合规扫描有关;"
|
||
"可尝试更换模型、缩短送入模型的正文/证据节选,或在服务商控制台调整内容过滤策略。"
|
||
)
|
||
|
||
|
||
def _format_llm_invoke_error_message(exc: BaseException) -> str:
    """Return the message text to put on the ``LLMCallError`` raised for ``exc``.

    Content-filter refusals map to the fixed ``_LLM_MSG_CONTENT_FILTER`` text;
    every other exception is rendered with ``str(exc)``.
    """
    return _LLM_MSG_CONTENT_FILTER if _is_content_filter_refusal(exc) else str(exc)
|
||
|
||
|
||
def _log_invoke_failure(*, agent: str, exc: BaseException, sync: bool) -> None:
    """Log an exception raised while calling the model.

    Args:
        agent: Agent label included in the log record.
        exc: The exception that was caught.
        sync: True for the synchronous entry point, False for the async one;
            selects the tag used in the error message.

    Content-filter refusals are logged at INFO level with at most the first
    500 characters of the exception text and no traceback. Anything else is
    logged through ``logger.exception``.
    """
    if _is_content_filter_refusal(exc):
        logger.info(
            "event=llm_content_filter_blocked agent={} sync={} detail={}",
            agent,
            sync,
            str(exc)[:500],
        )
        return

    tag = "llm_json_call" if sync else "allm_json_call"
    logger.bind(agent=agent).exception("{} invoke error: {}", tag, exc)
|
||
|
||
|
||
def _invoke_raw_sync(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str,
    retry_empty: bool,
) -> tuple[str, int]:
    """Invoke the model synchronously and return its stripped text content.

    The prompt is first passed through
    ``ensure_json_object_prompt_has_json_keyword`` and the model is bound with
    ``bind_json_object_mode(llm, max_tokens=max_tokens)``.

    Args:
        llm: Model object accepted by ``bind_json_object_mode``.
        prompt: Prompt text.
        max_tokens: Token limit forwarded when binding the model.
        agent: Label used in log lines; ``"json_object"`` when empty.
        retry_empty: When True, an empty response is retried once.

    Returns:
        ``(content, attempts_used)``. ``content`` is ``""`` when every attempt
        returned empty content, in which case ``attempts_used`` is the maximum
        number of attempts (2 with ``retry_empty``, otherwise 1).
    """
    prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
    bound = bind_json_object_mode(llm, max_tokens=max_tokens)
    tag = agent or "json_object"
    sha = _prompt_sha12(prompt_for_api)
    attempts = 2 if retry_empty else 1

    for attempt in range(attempts):
        response = bound.invoke(prompt_for_api)
        # A missing or None ``content`` attribute is treated as empty text.
        content = (getattr(response, "content", None) or "").strip()
        if content:
            if attempt > 0:
                logger.info(
                    "json_object 空内容重试成功 agent={} prompt_sha12={}",
                    tag,
                    sha,
                )
            return content, attempt + 1
        if attempt == 0 and retry_empty:
            logger.warning(
                "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}",
                tag,
                attempt,
                sha,
            )

    # Every attempt came back empty.
    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha)
    return "", attempts
|
||
|
||
|
||
async def _invoke_raw_async(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str,
    retry_empty: bool,
) -> tuple[str, int]:
    """Invoke the model asynchronously and return its stripped text content.

    The prompt is first passed through
    ``ensure_json_object_prompt_has_json_keyword`` and the model is bound with
    ``bind_json_object_mode(llm, max_tokens=max_tokens)``; the bound model is
    then awaited via ``ainvoke``.

    Args:
        llm: Model object accepted by ``bind_json_object_mode``.
        prompt: Prompt text.
        max_tokens: Token limit forwarded when binding the model.
        agent: Label used in log lines; ``"json_object"`` when empty.
        retry_empty: When True, an empty response is retried once.

    Returns:
        ``(content, attempts_used)``. ``content`` is ``""`` when every attempt
        returned empty content, in which case ``attempts_used`` is the maximum
        number of attempts (2 with ``retry_empty``, otherwise 1).
    """
    prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
    bound = bind_json_object_mode(llm, max_tokens=max_tokens)
    tag = agent or "json_object"
    sha = _prompt_sha12(prompt_for_api)
    attempts = 2 if retry_empty else 1

    for attempt in range(attempts):
        response = await bound.ainvoke(prompt_for_api)
        # A missing or None ``content`` attribute is treated as empty text.
        content = (getattr(response, "content", None) or "").strip()
        if content:
            if attempt > 0:
                logger.info(
                    "json_object 空内容重试成功 agent={} prompt_sha12={}",
                    tag,
                    sha,
                )
            return content, attempt + 1
        if attempt == 0 and retry_empty:
            logger.warning(
                "json_object 返回空 content,将重试 agent={} attempt={} prompt_sha12={}",
                tag,
                attempt,
                sha,
            )

    # Every attempt came back empty.
    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha)
    return "", attempts
|
||
|
||
|
||
def _parse_and_validate(
    raw: str,
    schema: type[T],
    *,
    agent: str,
) -> T:
    """Decode ``raw`` as JSON and validate it with ``schema.model_validate``.

    Args:
        raw: Raw model output.
        schema: Pydantic model class to validate against.
        agent: Label used in the compatibility-path log line.

    Returns:
        The validated ``schema`` instance.

    Raises:
        LLMCallError: ``kind="decode"`` when ``raw`` is empty or is not valid
            JSON even after the compatibility retry; ``kind="validation"``
            when the decoded data fails schema validation. Except for the
            empty-input case, ``raw_content`` carries at most the first 4096
            characters of the stripped input.
    """
    s = (raw or "").strip()
    if not s:
        raise LLMCallError(
            "decode", "empty llm content for json parse", raw_content=raw
        )

    data: Any
    try:
        data = json.loads(s)
    except json.JSONDecodeError:
        # Compatibility retry: runs only after a strict decode failure.
        stripped = extract_json_payload(s)
        if stripped != s:
            # Record every hit so the compat path can be retired once unused.
            logger.warning(
                "event=llm_json_compat_strip_hit agent={} prompt_kind=decode_retry",
                agent,
            )
        try:
            data = json.loads(stripped)
        except json.JSONDecodeError as e:
            raise LLMCallError(
                "decode",
                f"json decode failed: {e}",
                raw_content=s[:4096],
            ) from e

    try:
        return schema.model_validate(data)
    except ValidationError as e:
        raise LLMCallError(
            "validation",
            f"pydantic validation failed: {e}",
            raw_content=s[:4096],
        ) from e
|
||
|
||
|
||
def _emit_meta(
    *,
    agent: str,
    schema_name: str,
    max_tokens: int,
    t0: float,
    attempts: int,
    parse_ok: bool,
    used_fallback: bool,
    error_kind: str | None,
) -> None:
    """Build an ``LLMCallMeta`` record and log it as ``llm_json_call_done``.

    Args:
        agent: Agent label.
        schema_name: Name of the target schema.
        max_tokens: Token limit used for the call.
        t0: ``time.perf_counter()`` reading taken when the call started; the
            duration is computed from it here.
        attempts: Number of invocation attempts to report.
        parse_ok: Whether parsing and validation succeeded.
        used_fallback: Whether a fallback factory was available on failure.
        error_kind: Failing stage, or None on success.
    """
    elapsed_ms = (time.perf_counter() - t0) * 1000
    meta = LLMCallMeta(
        agent=agent,
        schema_name=schema_name,
        max_tokens=max_tokens,
        duration_ms=elapsed_ms,
        attempts=attempts,
        parse_ok=parse_ok,
        used_fallback=used_fallback,
        error_kind=error_kind,
    )
    # Fields are bound as structured context; the duration is rounded to two
    # decimals for the log record only.
    logger.bind(
        event="llm_json_call",
        agent=meta.agent,
        schema=meta.schema_name,
        max_tokens=meta.max_tokens,
        duration_ms=round(meta.duration_ms, 2),
        attempts=meta.attempts,
        parse_ok=meta.parse_ok,
        used_fallback=meta.used_fallback,
        error_kind=meta.error_kind,
    ).info("llm_json_call_done")
|
||
|
||
|
||
def llm_json_call(
    llm: Any,
    prompt: str,
    schema: type[T],
    *,
    max_tokens: int,
    agent: str,
    fallback_factory: Callable[[], T] | None = None,
    retry_empty: bool = True,
) -> T:
    """Synchronously invoke ``llm``, parse its JSON and validate it with ``schema``.

    Args:
        llm: Model object passed to the raw-invoke helper.
        prompt: Prompt text.
        schema: Pydantic model class the output must validate against.
        max_tokens: Token limit for the call.
        agent: Label used in logs and metadata.
        fallback_factory: Optional zero-argument callable; when given, its
            result is returned on any failure instead of raising.
        retry_empty: Retry once when the model returns empty content.

    Returns:
        The validated ``schema`` instance, or ``fallback_factory()`` on failure.

    Raises:
        LLMCallError: When the call fails and no ``fallback_factory`` is given.
            Decode/validation errors are re-raised unchanged; any other
            exception raised inside the call is wrapped with ``kind="invoke"``.
    """
    t0 = time.perf_counter()
    schema_name = getattr(schema, "__name__", str(schema))
    has_fallback = fallback_factory is not None
    # Stays 0 / "" when the invoke helper raises before returning.
    attempts_used = 0
    raw = ""

    def _report(
        *, parse_ok: bool, used_fallback: bool, error_kind: str | None
    ) -> None:
        # Shared by all three outcomes. Reads ``attempts_used`` and ``raw``
        # at call time, so it reports whatever the try block managed to set.
        _emit_meta(
            agent=agent,
            schema_name=schema_name,
            max_tokens=max_tokens,
            t0=t0,
            attempts=attempts_used,
            parse_ok=parse_ok,
            used_fallback=used_fallback,
            error_kind=error_kind,
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw)

    try:
        raw, attempts_used = _invoke_raw_sync(
            llm,
            prompt,
            max_tokens=max_tokens,
            agent=agent,
            retry_empty=retry_empty,
        )
        out = _parse_and_validate(raw, schema, agent=agent)
        _report(parse_ok=True, used_fallback=False, error_kind=None)
        return out
    except LLMCallError as e:
        # Decode/validation failure raised by the parsing step.
        _report(parse_ok=False, used_fallback=has_fallback, error_kind=e.kind)
        if fallback_factory is not None:
            return fallback_factory()
        raise
    except Exception as e:
        # Any other exception from the try block is reported as "invoke".
        _log_invoke_failure(agent=agent, exc=e, sync=True)
        _report(parse_ok=False, used_fallback=has_fallback, error_kind="invoke")
        if fallback_factory is not None:
            return fallback_factory()
        raise LLMCallError(
            "invoke",
            _format_llm_invoke_error_message(e),
            raw_content=raw[:4096] if raw else None,
        ) from e
|
||
|
||
|
||
async def allm_json_call(
    llm: Any,
    prompt: str,
    schema: type[T],
    *,
    max_tokens: int,
    agent: str,
    fallback_factory: Callable[[], T] | None = None,
    retry_empty: bool = True,
) -> T:
    """Asynchronously invoke ``llm``, parse its JSON and validate it with ``schema``.

    Same semantics as the synchronous ``llm_json_call``, except that the model
    is awaited through the async raw-invoke helper.

    Args:
        llm: Model object passed to the raw-invoke helper.
        prompt: Prompt text.
        schema: Pydantic model class the output must validate against.
        max_tokens: Token limit for the call.
        agent: Label used in logs and metadata.
        fallback_factory: Optional zero-argument callable; when given, its
            result is returned on any failure instead of raising.
        retry_empty: Retry once when the model returns empty content.

    Returns:
        The validated ``schema`` instance, or ``fallback_factory()`` on failure.

    Raises:
        LLMCallError: When the call fails and no ``fallback_factory`` is given.
            Decode/validation errors are re-raised unchanged; any other
            exception raised inside the call is wrapped with ``kind="invoke"``.
    """
    t0 = time.perf_counter()
    schema_name = getattr(schema, "__name__", str(schema))
    has_fallback = fallback_factory is not None
    # Stays 0 / "" when the invoke helper raises before returning.
    attempts_used = 0
    raw = ""

    def _report(
        *, parse_ok: bool, used_fallback: bool, error_kind: str | None
    ) -> None:
        # Shared by all three outcomes. Reads ``attempts_used`` and ``raw``
        # at call time, so it reports whatever the try block managed to set.
        _emit_meta(
            agent=agent,
            schema_name=schema_name,
            max_tokens=max_tokens,
            t0=t0,
            attempts=attempts_used,
            parse_ok=parse_ok,
            used_fallback=used_fallback,
            error_kind=error_kind,
        )
        if agent_verbose_enabled():
            log_agent_payload(
                logger,
                f"{agent}.prompt",
                ensure_json_object_prompt_has_json_keyword(prompt),
            )
            log_agent_payload(logger, f"{agent}.response", raw)

    try:
        raw, attempts_used = await _invoke_raw_async(
            llm,
            prompt,
            max_tokens=max_tokens,
            agent=agent,
            retry_empty=retry_empty,
        )
        out = _parse_and_validate(raw, schema, agent=agent)
        _report(parse_ok=True, used_fallback=False, error_kind=None)
        return out
    except LLMCallError as e:
        # Decode/validation failure raised by the parsing step.
        _report(parse_ok=False, used_fallback=has_fallback, error_kind=e.kind)
        if fallback_factory is not None:
            return fallback_factory()
        raise
    except Exception as e:
        # Any other exception from the try block is reported as "invoke".
        _log_invoke_failure(agent=agent, exc=e, sync=False)
        _report(parse_ok=False, used_fallback=has_fallback, error_kind="invoke")
        if fallback_factory is not None:
            return fallback_factory()
        raise LLMCallError(
            "invoke",
            _format_llm_invoke_error_message(e),
            raw_content=raw[:4096] if raw else None,
        ) from e
|
||
|
||
|
||
# Public API of this module.
__all__ = [
    "LLMCallError",
    "LLMCallMeta",
    "allm_json_call",
    "llm_json_call",
]
|