Files
life-echo/api/app/agents/chat/helpers.py
2026-04-03 13:49:24 +08:00

125 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""聊天 Agent 共享工具:历史获取、格式化、存储"""
import hashlib
from dataclasses import dataclass
from datetime import datetime
from typing import Any, List
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from app.core.redis import redis_service
def _human_ai_rows(history: list[dict]) -> list[dict]:
    """Keep only the "human" and "ai" rows of a stored history, in order.

    This is the same filter get_history_messages relies on, so both paths
    see rows in exactly the order they appear in the Redis list.
    """
    allowed_roles = ("human", "ai")
    return list(filter(lambda row: row.get("role") in allowed_roles, history))
def _lc_messages_from_rows(rows: list[dict]) -> list[HumanMessage | AIMessage]:
    """Turn stored history rows into LangChain message objects.

    Rows with role "human" become HumanMessage and rows with role "ai" become
    AIMessage, built from the row's "content" value. Any other role is skipped.
    Input order is preserved.
    """
    converted: list[HumanMessage | AIMessage] = []
    for row in rows:
        kind = row.get("role")
        if kind == "human":
            message_cls = HumanMessage
        elif kind == "ai":
            message_cls = AIMessage
        else:
            continue
        converted.append(message_cls(content=row["content"]))
    return converted
@dataclass(frozen=True)
class HistoryWithWindow:
    """Result of a single Redis history read.

    Bundles the turn count computed over the full stored history with the
    truncated message list that is meant to be injected into the LLM prompt.
    """

    # Turn count over the full human/ai history, not just the window.
    turn_total: int
    # Truncated LangChain messages intended for the LLM context.
    window: list[HumanMessage | AIMessage]
async def get_history_with_window(
    conversation_id: str,
    *,
    max_pairs: int,
    max_chars: int,
) -> HistoryWithWindow:
    """Read the conversation from Redis once and build the LLM context window.

    ``turn_total`` is the number of human/ai rows in the *full* history
    divided by two. The window starts with at most the last
    ``max_pairs * 2`` human/ai rows; a non-positive ``max_pairs`` keeps all
    of them. The window is then trimmed from the oldest end so that the
    summed content length fits within ``max_chars``. Only the windowed rows
    are converted to LangChain messages.
    """
    rows = _human_ai_rows(
        await redis_service.get_conversation_history(conversation_id)
    )
    pair_count = len(rows) // 2

    if max_pairs > 0:
        candidate_rows = rows[-(max_pairs * 2):]
    else:
        candidate_rows = list(rows)
    candidates = _lc_messages_from_rows(candidate_rows)

    # Walk newest -> oldest, accumulating content length. The first message
    # that pushes the total over the budget is dropped, along with everything
    # older than it. If the budget is never exceeded, everything is kept.
    first_kept = 0
    used_chars = 0
    for idx in reversed(range(len(candidates))):
        used_chars += len(getattr(candidates[idx], "content", "") or "")
        if used_chars > max_chars:
            first_kept = idx + 1
            break

    # Do not let the window begin with an assistant reply; drop one leading
    # AIMessage so the context starts on a human turn.
    if first_kept < len(candidates) and isinstance(
        candidates[first_kept], AIMessage
    ):
        first_kept += 1

    return HistoryWithWindow(
        turn_total=pair_count,
        window=candidates[first_kept:],
    )
async def get_history_messages(conversation_id: str) -> List[Any]:
    """Load the whole conversation from Redis as LangChain messages.

    Only "human" and "ai" rows are converted, in stored order. Unlike
    get_history_with_window, no pair or character limit is applied.
    """
    stored_rows = await redis_service.get_conversation_history(conversation_id)
    dialogue_rows = _human_ai_rows(stored_rows)
    return _lc_messages_from_rows(dialogue_rows)
def _sha12_utf8(text: str) -> str:
    """Return the first 12 hex digits of SHA-256 over *text* encoded as UTF-8.

    A falsy *text* (``None`` or ``""``) is hashed as the empty string.
    """
    payload = text.encode("utf-8") if text else b""
    full_digest = hashlib.sha256(payload).hexdigest()
    return full_digest[:12]
def format_history_string(
    messages: List[Any], *, omit_system_body: bool = False
) -> str:
    """Render LangChain messages as multi-paragraph text for debug logging.

    Each message becomes one "<Label>: <content>" paragraph. Paragraphs are
    separated by a blank line. System messages are always included, never
    silently dropped. With ``omit_system_body`` their text is replaced by its
    length and a 12-hex-digit SHA-256 prefix. Unrecognized message types are
    labelled with their class name.
    """

    def _render(msg: Any) -> str:
        if isinstance(msg, SystemMessage):
            if not omit_system_body:
                return f"System: {msg.content}"
            # Non-string (e.g. structured) content is stringified before hashing.
            body = msg.content if isinstance(msg.content, str) else str(msg.content)
            return f"System: <omitted total_len={len(body)} sha12={_sha12_utf8(body)}>"
        if isinstance(msg, HumanMessage):
            return f"Human: {msg.content}"
        if isinstance(msg, AIMessage):
            return f"Assistant: {msg.content}"
        fallback_content = getattr(msg, "content", None)
        return f"{type(msg).__name__}: {fallback_content}"

    return "\n\n".join(_render(m) for m in messages)
async def save_message(
    conversation_id: str,
    role: str,
    content: str,
    message_type: str = "text",
    voice_session_id: str | None = None,
    timestamp: datetime | str | int | None = None,
    audio_duration_seconds: int | None = None,
) -> None:
    """Append one message to the conversation stored in Redis.

    A ``datetime`` timestamp is serialized with ``isoformat()`` before being
    passed to ``redis_service.add_message``. ``str``, ``int`` and ``None``
    values are forwarded unchanged.
    """
    if isinstance(timestamp, datetime):
        # NOTE(review): a naive datetime serializes without a UTC offset;
        # confirm whether callers are expected to pass timezone-aware values.
        stored_timestamp = timestamp.isoformat()
    else:
        stored_timestamp = timestamp
    await redis_service.add_message(
        conversation_id,
        role,
        content,
        message_type=message_type,
        voice_session_id=voice_session_id,
        timestamp=stored_timestamp,
        audio_duration_seconds=audio_duration_seconds,
    )