"""Shared helpers for chat agents: fetching, formatting, and storing history."""

from dataclasses import dataclass
from datetime import datetime
from typing import Any, List

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from app.core.redis import redis_service


def _human_ai_rows(history: list[dict]) -> list[dict]:
    """Return only the rows whose role is "human" or "ai", in input order.

    This is the same filter that ``get_history_messages`` applies before
    converting rows to LangChain messages.
    """
    wanted_roles = ("human", "ai")
    return [row for row in history if row.get("role") in wanted_roles]


def _lc_messages_from_rows(rows: list[dict]) -> list[HumanMessage | AIMessage]:
    """Convert stored history rows into LangChain message objects.

    Rows with role "human" become ``HumanMessage`` and rows with role "ai"
    become ``AIMessage``; rows with any other role are skipped. A converted
    row must have a "content" key, otherwise ``KeyError`` is raised.
    """
    converted: list[HumanMessage | AIMessage] = []
    for row in rows:
        role = row.get("role")
        if role == "human":
            message_cls = HumanMessage
        elif role == "ai":
            message_cls = AIMessage
        else:
            # Not a conversational row; nothing to convert.
            continue
        converted.append(message_cls(content=row["content"]))
    return converted


@dataclass(frozen=True)
class HistoryWithWindow:
    """Result of a single history read: full turn count plus the trimmed window.

    Attributes:
        turn_total: Number of turns, computed from all human/ai rows
            (row count divided by two, rounded down).
        window: The truncated message list intended to be passed to the LLM.
    """

    turn_total: int
    window: list[HumanMessage | AIMessage]


async def get_history_with_window(
    conversation_id: str,
    *,
    max_pairs: int,
    max_chars: int,
) -> HistoryWithWindow:
    """Read the history once and return the turn count with a bounded window.

    The turn count is derived from every human/ai row, while LangChain
    messages are only built for the most recent slice.

    Args:
        conversation_id: Conversation whose history is read.
        max_pairs: Maximum number of human/ai pairs to keep (the last
            ``max_pairs * 2`` rows). A value of zero or less applies no
            pair limit.
        max_chars: Character budget for the window, counted from the newest
            message backwards. The message that pushes the running total
            over the budget, and everything older, is dropped.

    Returns:
        A ``HistoryWithWindow`` with the total turn count and the trimmed
        messages. The window may be empty, e.g. when the newest message
        alone exceeds ``max_chars``.
    """
    history = await redis_service.get_conversation_history(conversation_id)
    rows = _human_ai_rows(history)

    if max_pairs > 0:
        recent_rows = rows[-(max_pairs * 2):]
    else:
        recent_rows = rows[:]
    window = _lc_messages_from_rows(recent_rows)

    # Walk from newest to oldest, accumulating content length. If the budget
    # is never exceeded, every message is kept (first_kept stays 0).
    first_kept = 0
    used_chars = 0
    for index in reversed(range(len(window))):
        used_chars += len(getattr(window[index], "content", "") or "")
        if used_chars > max_chars:
            first_kept = index + 1
            break

    # Do not let the window open with an assistant message: skip one leading
    # AIMessage so the kept slice does not begin with it.
    if first_kept < len(window) and isinstance(window[first_kept], AIMessage):
        first_kept += 1

    return HistoryWithWindow(
        turn_total=len(rows) // 2,
        window=window[first_kept:],
    )


async def get_history_messages(conversation_id: str) -> List[Any]:
    """Fetch the conversation history from Redis as LangChain messages."""
    raw_history = await redis_service.get_conversation_history(conversation_id)
    conversational_rows = _human_ai_rows(raw_history)
    return _lc_messages_from_rows(conversational_rows)


def format_history_string(messages: List[Any]) -> str:
    """Format LangChain messages as multi-paragraph text for debug logging.

    System messages are included rather than silently skipped. Objects that
    are not System/Human/AI messages are rendered with their type name and
    their ``content`` attribute (``None`` when absent). Entries are separated
    by a blank line.
    """

    def render(message: Any) -> str:
        # Checked in the order System, Human, AI; anything else falls through.
        if isinstance(message, SystemMessage):
            return f"System: {message.content}"
        if isinstance(message, HumanMessage):
            return f"Human: {message.content}"
        if isinstance(message, AIMessage):
            return f"Assistant: {message.content}"
        fallback_content = getattr(message, "content", None)
        return f"{type(message).__name__}: {fallback_content}"

    return "\n\n".join(render(message) for message in messages)


async def save_message(
    conversation_id: str,
    role: str,
    content: str,
    message_type: str = "text",
    voice_session_id: str | None = None,
    timestamp: datetime | str | int | None = None,
    audio_duration_seconds: int | None = None,
) -> None:
    """Save one message to Redis via ``redis_service.add_message``.

    Args:
        conversation_id: Conversation the message belongs to.
        role: Role label stored with the message.
        content: Message text.
        message_type: Stored message type; defaults to "text".
        voice_session_id: Optional voice session identifier.
        timestamp: Optional timestamp. A ``datetime`` is converted with
            ``isoformat()``; any other value is forwarded unchanged.
        audio_duration_seconds: Optional audio duration in seconds.
    """
    if isinstance(timestamp, datetime):
        stored_timestamp: str | int | None = timestamp.isoformat()
    else:
        stored_timestamp = timestamp

    await redis_service.add_message(
        conversation_id,
        role,
        content,
        message_type=message_type,
        voice_session_id=voice_session_id,
        timestamp=stored_timestamp,
        audio_duration_seconds=audio_duration_seconds,
    )