Files
life-echo/api/app/core/langchain_llm.py
Kevin a3f61fcc0f feat(api+app): 对话阶段化、回忆录流水线与客户端会话体验
- DB: segments 用户输入文本(Alembic 0002)
- Chat: 阶段检测/阶段提示/回复限制,编排与访谈/画像 prompts 调整
- Memoir: 忠实度检查 agent,叙事与分类等链路更新
- Core: agent 日志、Alembic 启动、LangChain/日志/配置等
- Story: time_hints;Memory 检索与相关测试
- Expo: 助手头像、会话页与消息拆分、实时会话与文案/i18n
- Docs/scripts/tests: 迁移脚本、LLM JSON/记忆检索文档、新增单测
2026-03-26 12:13:36 +08:00

149 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
与 `get_llm_provider().langchain_llm` 配合使用的 LangChain Runnable 约定。
langchain-openai 要求用顶层 `response_format` 绑定 JSON 模式,禁止对 `.bind()` 传入
`model_kwargs={"response_format": ...}`(会错误传入底层 `completions.create`)。
"""
from __future__ import annotations
import hashlib
import time
from typing import Any
from app.core.agent_logging import (
agent_summary_enabled,
agent_verbose_enabled,
log_agent_payload,
)
from app.core.logging import get_logger
# Module-level logger shared by every helper in this file; the log calls
# below pass `{}`-style placeholders followed by positional arguments.
logger = get_logger(__name__)
def bind_json_object_mode(llm: Any, *, max_tokens: int) -> Any:
    """Return ``llm`` bound to JSON-object mode with a token limit.

    The result is whatever ``llm.bind`` returns (usually a Runnable wrapping
    a ChatOpenAI). ``response_format`` is passed as a top-level keyword to
    ``bind`` rather than nested under ``model_kwargs``.

    Args:
        llm: Object exposing a ``bind(**kwargs)`` method.
        max_tokens: Value forwarded as the ``max_tokens`` binding.

    Returns:
        The object produced by ``llm.bind``.
    """
    bind_kwargs = {
        "response_format": {"type": "json_object"},
        "max_tokens": max_tokens,
    }
    return llm.bind(**bind_kwargs)
def _prompt_sha12(prompt: str) -> str:
    """Return the first 12 hex characters of the SHA-256 digest of ``prompt``.

    A falsy prompt (``None`` or ``""``) is hashed as the empty string. The
    text is encoded as UTF-8 before hashing.
    """
    text = prompt if prompt else ""
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()[:12]
def invoke_json_object(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str | None = None,
    retry_empty: bool = True,
) -> str:
    """Synchronously call ``llm`` in JSON-object mode and return its content.

    When the reply content is empty, the call can optionally be retried once
    (mitigates occasional empty output from DeepSeek). Depends only on
    ``bind_json_object_mode``; it does not reference features.

    Args:
        llm: Model object accepted by ``bind_json_object_mode``.
        prompt: Prompt passed unchanged to ``invoke``.
        max_tokens: Token limit forwarded to ``bind_json_object_mode``.
        agent: Label used in log records; defaults to ``"json_object"``.
        retry_empty: When true, an empty first reply triggers one retry.

    Returns:
        The stripped reply content, or ``""`` if every attempt was empty.
    """
    runnable = bind_json_object_mode(llm, max_tokens=max_tokens)
    label = agent if agent else "json_object"
    digest = _prompt_sha12(prompt)
    max_attempts = 1 if not retry_empty else 2
    started = time.perf_counter()
    text = ""
    for index in range(max_attempts):
        reply = runnable.invoke(prompt)
        raw = getattr(reply, "content", None)
        text = (raw if raw else "").strip()
        if not text:
            # Only the first empty reply announces a retry, and only when
            # retrying is enabled.
            if retry_empty and index == 0:
                logger.warning(
                    "json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
                    label,
                    index,
                    digest,
                )
            continue
        if index > 0:
            logger.info(
                "json_object 空内容重试成功 agent={} prompt_sha12={}",
                label,
                digest,
            )
        _log_json_object_done(
            label, digest, prompt, text, index + 1, started, success=True
        )
        return text
    # Every attempt came back empty.
    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", label, digest)
    _log_json_object_done(
        label, digest, prompt, text, max_attempts, started, success=False
    )
    return ""
async def ainvoke_json_object(
    llm: Any,
    prompt: str,
    *,
    max_tokens: int,
    agent: str | None = None,
    retry_empty: bool = True,
) -> str:
    """Asynchronous counterpart of ``invoke_json_object``.

    Awaits ``ainvoke`` on the bound model. When the reply content is empty
    and ``retry_empty`` is true, the call is retried once.

    Args:
        llm: Model object accepted by ``bind_json_object_mode``.
        prompt: Prompt passed unchanged to ``ainvoke``.
        max_tokens: Token limit forwarded to ``bind_json_object_mode``.
        agent: Label used in log records; defaults to ``"json_object"``.
        retry_empty: When true, an empty first reply triggers one retry.

    Returns:
        The stripped reply content, or ``""`` if every attempt was empty.
    """
    runnable = bind_json_object_mode(llm, max_tokens=max_tokens)
    label = agent if agent else "json_object"
    digest = _prompt_sha12(prompt)
    max_attempts = 1 if not retry_empty else 2
    started = time.perf_counter()
    text = ""
    for index in range(max_attempts):
        reply = await runnable.ainvoke(prompt)
        raw = getattr(reply, "content", None)
        text = (raw if raw else "").strip()
        if not text:
            # Only the first empty reply announces a retry, and only when
            # retrying is enabled.
            if retry_empty and index == 0:
                logger.warning(
                    "json_object 返回空 content将重试 agent={} attempt={} prompt_sha12={}",
                    label,
                    index,
                    digest,
                )
            continue
        if index > 0:
            logger.info(
                "json_object 空内容重试成功 agent={} prompt_sha12={}",
                label,
                digest,
            )
        _log_json_object_done(
            label, digest, prompt, text, index + 1, started, success=True
        )
        return text
    # Every attempt came back empty.
    logger.warning("json_object 仍为空 agent={} prompt_sha12={}", label, digest)
    _log_json_object_done(
        label, digest, prompt, text, max_attempts, started, success=False
    )
    return ""
def _log_json_object_done(
    tag: str,
    sha: str,
    prompt: str,
    content: str,
    attempts_used: int,
    t0: float,
    *,
    success: bool,
) -> None:
    """Log the outcome of one finished JSON-object call.

    Emits a one-line summary when ``agent_summary_enabled()`` is true and
    the full prompt/response payloads when ``agent_verbose_enabled()`` is
    true.

    Args:
        tag: Agent label used in the summary and as the payload-name prefix.
        sha: Short prompt hash included in the summary line.
        prompt: Prompt text; a falsy value counts as length 0.
        content: Response text; a falsy value counts as length 0.
        attempts_used: Number of attempts reported in the summary.
        t0: ``time.perf_counter()`` reading taken before the call started.
        success: Whether the call produced non-empty content.
    """
    # Measure first so the duration excludes the logging work below.
    elapsed_ms = (time.perf_counter() - t0) * 1000
    if agent_summary_enabled():
        template = (
            "llm_json_object agent={} prompt_sha12={} duration_ms={:.2f} "
            "prompt_len={} response_len={} attempts={} success={}"
        )
        prompt_len = len(prompt or "")
        response_len = len(content or "")
        logger.info(
            template,
            tag,
            sha,
            elapsed_ms,
            prompt_len,
            response_len,
            attempts_used,
            success,
        )
    if not agent_verbose_enabled():
        return
    log_agent_payload(logger, f"{tag}.prompt", prompt)
    log_agent_payload(logger, f"{tag}.response", content)