Files
life-echo/api/app/core/langchain_llm.py

173 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
与 `get_llm_provider().langchain_llm` 配合使用的 LangChain Runnable 约定。
langchain-openai 要求用顶层 `response_format` 绑定 JSON 模式,禁止对 `.bind()` 传入
`model_kwargs={"response_format": ...}`(会错误传入底层 `completions.create`)。
"""
from __future__ import annotations
import hashlib
import time
from typing import Any
from app.core.agent_logging import (
agent_summary_enabled,
agent_verbose_enabled,
log_agent_payload,
)
from app.core.logging import get_logger
# Module-level logger used by the JSON-object helpers in this module. Call sites
# pass "{}"-style placeholders (e.g. "agent={}"), so get_logger presumably
# returns a loguru-style logger — NOTE(review): confirm in app.core.logging.
logger = get_logger(__name__)
# With response_format=json_object, OpenAI / DeepSeek require the prompt to
# contain the substring "json" (see the DeepSeek JSON Output guide). This suffix
# is appended to prompts that lack it.
_JSON_OBJECT_PROMPT_SUFFIX = (
    "\n\n【JSON】只输出一个合法 JSON 对象,不要其它说明文字或 markdown。"
)


def ensure_json_object_prompt_has_json_keyword(prompt: str) -> str:
    """
    Guarantee that ``prompt`` mentions "json" before a JSON-object-mode request.

    The check is case-insensitive (``str.casefold``). A prompt that already
    contains the keyword is returned unchanged. Otherwise, trailing whitespace
    is stripped and ``_JSON_OBJECT_PROMPT_SUFFIX`` is appended. ``None`` or an
    empty prompt is treated as ``""``, so the result is then just the suffix.
    Every ``response_format: json_object`` call should pass its prompt through
    this function first.
    """
    text = prompt if prompt else ""
    if "json" not in text.casefold():
        text = text.rstrip() + _JSON_OBJECT_PROMPT_SUFFIX
    return text
def bind_json_object_mode(llm: Any, *, max_tokens: int) -> Any:
    """
    Bind JSON-object mode and a token cap onto ``llm``.

    ``response_format`` is passed as a top-level ``.bind()`` keyword, never
    inside ``model_kwargs``, together with ``max_tokens``. The function returns
    whatever ``llm.bind`` returns, typically a bound ``ChatOpenAI`` Runnable.
    """
    bind_kwargs: dict[str, Any] = {
        "response_format": {"type": "json_object"},
        "max_tokens": max_tokens,
    }
    return llm.bind(**bind_kwargs)
def _prompt_sha12(prompt: str) -> str:
    """Return the first 12 hex digits of SHA-256(UTF-8 ``prompt``); ``None`` hashes as ``""``."""
    digest = hashlib.sha256()
    digest.update((prompt if prompt else "").encode("utf-8"))
    return digest.hexdigest()[:12]
def invoke_json_object(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str | None = None,
    retry_empty: bool = True,
) -> str:
    """
    Call ``llm`` synchronously in JSON-object mode and return the reply text.

    The prompt is normalised with ``ensure_json_object_prompt_has_json_keyword``
    (JSON mode needs the word "json" in the prompt). The model is then bound
    with ``bind_json_object_mode``. If the reply is empty and ``retry_empty``
    is true, the request is sent once more; this mitigates DeepSeek's
    occasional empty output. The function relies only on helpers in this
    module and does not reference ``features``.

    Args:
        llm: Model exposing ``.bind(**kwargs)`` (typically ``ChatOpenAI``).
        prompt: Prompt text; ``None`` is treated as ``""``.
        max_tokens: Completion-token cap bound onto the request.
        agent: Tag used in log lines; defaults to ``"json_object"``.
        retry_empty: Retry once when the first reply is empty.

    Returns:
        The stripped reply text, or ``""`` if every attempt was empty.
        Exceptions raised by ``invoke`` are not caught.
    """
    prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
    bound = bind_json_object_mode(llm, max_tokens=max_tokens)
    tag = agent or "json_object"
    sha = _prompt_sha12(prompt_for_api)
    attempts = 2 if retry_empty else 1
    t0 = time.perf_counter()
    last_content = ""
    for attempt in range(attempts):
        response = bound.invoke(prompt_for_api)
        # Chat models return a message whose ``content`` is either a ``str`` or
        # a list of content blocks (``str`` or ``{"type": "text", "text": ...}``).
        # Plain LLMs return a bare ``str``. Calling ``.strip()`` on list content
        # would raise AttributeError, so normalise to text first.
        if isinstance(response, str):
            raw: Any = response
        else:
            raw = getattr(response, "content", None)
        if isinstance(raw, list):
            raw = "".join(
                part if isinstance(part, str) else str(part.get("text") or "")
                for part in raw
                if isinstance(part, (str, dict))
            )
        content = raw.strip() if isinstance(raw, str) else ""
        last_content = content
        if content:
            if attempt > 0:
                logger.info(
                    "json_object 空内容重试成功 agent={} prompt_sha12={}",
                    tag,
                    sha,
                )
            _log_json_object_done(
                tag, sha, prompt_for_api, content, attempt + 1, t0, success=True
            )
            return content
        # Warn about the retry only when another attempt will actually follow.
        if attempt + 1 < attempts:
            logger.warning(
                "json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
                tag,
                attempt,
                sha,
            )
    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha)
    _log_json_object_done(
        tag, sha, prompt_for_api, last_content, attempts, t0, success=False
    )
    return ""
async def ainvoke_json_object(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str | None = None,
    retry_empty: bool = True,
) -> str:
    """
    Async counterpart of ``invoke_json_object``, using ``bound.ainvoke``.

    The prompt is normalised with ``ensure_json_object_prompt_has_json_keyword``
    and the model is bound with ``bind_json_object_mode``. If the reply is
    empty and ``retry_empty`` is true, the request is awaited once more; this
    mitigates DeepSeek's occasional empty output.

    Args:
        llm: Model exposing ``.bind(**kwargs)`` (typically ``ChatOpenAI``).
        prompt: Prompt text; ``None`` is treated as ``""``.
        max_tokens: Completion-token cap bound onto the request.
        agent: Tag used in log lines; defaults to ``"json_object"``.
        retry_empty: Retry once when the first reply is empty.

    Returns:
        The stripped reply text, or ``""`` if every attempt was empty.
        Exceptions raised by ``ainvoke`` are not caught.
    """
    prompt_for_api = ensure_json_object_prompt_has_json_keyword(prompt)
    bound = bind_json_object_mode(llm, max_tokens=max_tokens)
    tag = agent or "json_object"
    sha = _prompt_sha12(prompt_for_api)
    attempts = 2 if retry_empty else 1
    t0 = time.perf_counter()
    last_content = ""
    for attempt in range(attempts):
        response = await bound.ainvoke(prompt_for_api)
        # Chat models return a message whose ``content`` is either a ``str`` or
        # a list of content blocks (``str`` or ``{"type": "text", "text": ...}``).
        # Plain LLMs return a bare ``str``. Calling ``.strip()`` on list content
        # would raise AttributeError, so normalise to text first.
        if isinstance(response, str):
            raw: Any = response
        else:
            raw = getattr(response, "content", None)
        if isinstance(raw, list):
            raw = "".join(
                part if isinstance(part, str) else str(part.get("text") or "")
                for part in raw
                if isinstance(part, (str, dict))
            )
        content = raw.strip() if isinstance(raw, str) else ""
        last_content = content
        if content:
            if attempt > 0:
                logger.info(
                    "json_object 空内容重试成功 agent={} prompt_sha12={}",
                    tag,
                    sha,
                )
            _log_json_object_done(
                tag, sha, prompt_for_api, content, attempt + 1, t0, success=True
            )
            return content
        # Warn about the retry only when another attempt will actually follow.
        if attempt + 1 < attempts:
            logger.warning(
                "json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
                tag,
                attempt,
                sha,
            )
    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", tag, sha)
    _log_json_object_done(
        tag, sha, prompt_for_api, last_content, attempts, t0, success=False
    )
    return ""
def _log_json_object_done(
    tag: str,
    sha: str,
    prompt: str,
    content: str,
    attempts_used: int,
    t0: float,
    *,
    success: bool,
) -> None:
    """
    Emit end-of-call diagnostics for a JSON-object request.

    * When ``agent_summary_enabled()`` is true, one INFO line is logged. It
      records the tag, prompt hash, milliseconds elapsed since ``t0`` (a
      ``time.perf_counter()`` reading), prompt and response lengths, attempts
      used, and the success flag.
    * When ``agent_verbose_enabled()`` is true, the full prompt and response
      are also logged through ``log_agent_payload``, under the labels
      ``<tag>.prompt`` and ``<tag>.response``.
    """
    elapsed_ms = (time.perf_counter() - t0) * 1000
    if agent_summary_enabled():
        logger.info(
            "llm_json_object agent={} prompt_sha12={} duration_ms={:.2f} "
            "prompt_char_count={} response_len={} attempts={} success={}",
            tag,
            sha,
            elapsed_ms,
            len(prompt) if prompt else 0,
            len(content) if content else 0,
            attempts_used,
            success,
        )
    if not agent_verbose_enabled():
        return
    for suffix, payload in ((".prompt", prompt), (".response", content)):
        log_agent_payload(logger, f"{tag}{suffix}", payload)