125 lines
4.1 KiB
Python
125 lines
4.1 KiB
Python
"""聊天 Agent 共享工具:历史获取、格式化、存储"""
|
||
|
||
import hashlib
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from typing import Any, List
|
||
|
||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||
|
||
from app.core.redis import redis_service
|
||
|
||
|
||
def _human_ai_rows(history: list[dict]) -> list[dict]:
    """Keep only the "human" and "ai" rows of a stored history.

    Uses the same filter as get_history_messages; the relative order of the
    rows (i.e. the Redis list order) is preserved.
    """

    def _is_dialogue_row(row: dict) -> bool:
        return row.get("role") in ("human", "ai")

    return list(filter(_is_dialogue_row, history))
|
||
|
||
|
||
def _lc_messages_from_rows(rows: list[dict]) -> list[HumanMessage | AIMessage]:
|
||
out: list[HumanMessage | AIMessage] = []
|
||
for msg in rows:
|
||
role = msg.get("role")
|
||
if role == "human":
|
||
out.append(HumanMessage(content=msg["content"]))
|
||
elif role == "ai":
|
||
out.append(AIMessage(content=msg["content"]))
|
||
return out
|
||
|
||
|
||
@dataclass(frozen=True)
class HistoryWithWindow:
    """Result of a single Redis history read.

    Pairs the turn count computed over the full human/ai history with the
    truncated message list that is injected into the LLM.
    """

    # Turn count over the full stored history: (number of human/ai rows) // 2,
    # as computed in get_history_with_window. Independent of any truncation.
    turn_total: int
    # Truncated LangChain messages to send to the LLM. frozen=True only stops
    # rebinding this attribute; the list object itself remains mutable.
    window: list[HumanMessage | AIMessage]
|
||
|
||
|
||
def _trim_to_char_budget(
    window: list[HumanMessage | AIMessage], max_chars: int
) -> list[HumanMessage | AIMessage]:
    """Return the newest suffix of *window* that fits in *max_chars* characters.

    Content lengths are summed from the newest message backwards. The message
    that pushes the running total above ``max_chars`` is dropped, together
    with everything older. Any AI messages left at the front of the result are
    then removed, so the window never opens with an assistant reply.
    """
    total_chars = 0
    start = 0
    for i in range(len(window) - 1, -1, -1):
        content = getattr(window[i], "content", "") or ""
        total_chars += len(content)
        if total_chars > max_chars:
            start = i + 1
            break
    # Skip every leading AI message, not just one. Otherwise consecutive
    # assistant rows at the cut point would still open the window on an AI turn.
    while start < len(window) and isinstance(window[start], AIMessage):
        start += 1
    return window[start:]


async def get_history_with_window(
    conversation_id: str,
    *,
    max_pairs: int,
    max_chars: int,
) -> HistoryWithWindow:
    """Read the conversation history from Redis once and build the LLM window.

    ``turn_total`` is derived from the count of all human/ai rows
    (``len // 2``), regardless of the window limits. Only the windowed slice
    is converted to LangChain messages.

    Args:
        conversation_id: Conversation key passed to
            ``redis_service.get_conversation_history``.
        max_pairs: Keep at most the last ``2 * max_pairs`` human/ai rows. A
            value <= 0 disables this limit.
        max_chars: Character budget for the returned window. Content lengths
            are summed from the newest message backwards, and older messages
            beyond the budget are dropped. If the newest message alone exceeds
            the budget (e.g. ``max_chars`` is 0), the window is empty.

    Returns:
        HistoryWithWindow with the full-history turn count and the trimmed
        window. The window never starts with an AI message. If the stored
        history ends with an unanswered human row, the row slice begins on an
        AI row, and that row is then dropped by this alignment.
    """
    history = await redis_service.get_conversation_history(conversation_id)
    human_ai = _human_ai_rows(history)
    turn_total = len(human_ai) // 2
    window_raw = human_ai[-(max_pairs * 2) :] if max_pairs > 0 else human_ai
    window = _lc_messages_from_rows(window_raw)
    return HistoryWithWindow(
        turn_total=turn_total,
        window=_trim_to_char_budget(window, max_chars),
    )
|
||
|
||
|
||
async def get_history_messages(conversation_id: str) -> List[Any]:
    """Load a conversation's history from Redis as LangChain messages.

    Only rows whose role is "human" or "ai" are kept, in stored order, and
    they are converted to HumanMessage / AIMessage objects.
    """
    stored_rows = await redis_service.get_conversation_history(conversation_id)
    dialogue_rows = _human_ai_rows(stored_rows)
    messages = _lc_messages_from_rows(dialogue_rows)
    return messages
|
||
|
||
|
||
def _sha12_utf8(text: str) -> str:
    """Return a 12-character hex fingerprint of *text*.

    The text is UTF-8 encoded and hashed with SHA-256, and the first 12 hex
    digits of the digest are returned. Falsy input (``None`` or ``""``) is
    hashed as the empty byte string.
    """
    payload = text.encode("utf-8") if text else b""
    hex_digest = hashlib.sha256(payload).hexdigest()
    return hex_digest[:12]
|
||
|
||
|
||
def format_history_string(
    messages: List[Any], *, omit_system_body: bool = False
) -> str:
    """Render LangChain messages as debug-log text, one section per message.

    Sections are separated by a blank line. System messages are always
    included, never silently skipped. When ``omit_system_body`` is true, a
    system message's body is replaced by its length and the first 12 hex
    digits of its UTF-8 SHA-256. Human and AI messages are prefixed with
    "Human:" and "Assistant:". Any other object is rendered as
    ``<ClassName>: <content>``, using None when it has no ``content``.
    """

    def _render(message: Any) -> str:
        if isinstance(message, SystemMessage):
            if not omit_system_body:
                return f"System: {message.content}"
            raw = message.content
            body = raw if isinstance(raw, str) else str(raw)
            return f"System: <omitted total_len={len(body)} sha12={_sha12_utf8(body)}>"
        if isinstance(message, HumanMessage):
            return f"Human: {message.content}"
        if isinstance(message, AIMessage):
            return f"Assistant: {message.content}"
        fallback_content = getattr(message, "content", None)
        return f"{type(message).__name__}: {fallback_content}"

    return "\n\n".join(_render(message) for message in messages)
|
||
|
||
|
||
async def save_message(
    conversation_id: str,
    role: str,
    content: str,
    message_type: str = "text",
    voice_session_id: str | None = None,
    timestamp: datetime | str | int | None = None,
    audio_duration_seconds: int | None = None,
) -> None:
    """Persist one chat message to Redis via ``redis_service.add_message``.

    A ``datetime`` timestamp is serialized with ``isoformat()`` before being
    stored. ``str``, ``int`` and ``None`` timestamps are forwarded unchanged.
    """
    stored_timestamp = (
        timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp
    )
    await redis_service.add_message(
        conversation_id,
        role,
        content,
        message_type=message_type,
        voice_session_id=voice_session_id,
        timestamp=stored_timestamp,
        audio_duration_seconds=audio_duration_seconds,
    )
|