"""Redis client plus session/cache helpers.

Used by the application lifecycle, conversation history, task tracking and
similar features. Configuration is read from ``app.core.config.settings``;
business code must not scatter ``os.getenv`` calls around.
"""

import contextlib
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import redis.asyncio as aioredis

from app.core.config import settings
from app.core.logging import get_logger
from app.core.runtime_constants import misc_defaults, redis_defaults

logger = get_logger(__name__)


class RedisService:
    """Redis service: connection management, conversation history, generic cache."""

    def __init__(self) -> None:
        self.redis_url = settings.redis_url_resolved
        self._client: Optional[aioredis.Redis] = None
        # TTL passed to EXPIRE for conversation-history keys.
        self.session_ttl = misc_defaults.redis_session_ttl

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    @staticmethod
    async def _close_client(client: aioredis.Redis) -> None:
        """Close *client*, preferring ``aclose()`` when the installed redis-py has it.

        redis-py 5.0.1 deprecated ``Redis.close()`` in favour of ``aclose()``;
        older releases only provide ``close()``.
        """
        aclose = getattr(client, "aclose", None)
        if aclose is not None:
            await aclose()
        else:
            await client.close()

    def _log_connection_target(self) -> None:
        """Log host/port of the configured URL at debug level (no credentials)."""
        try:
            parsed = urlparse(self.redis_url)
            logger.debug(
                "Redis 连接 host={} port={}",
                parsed.hostname or "",
                parsed.port or "",
            )
        except Exception:
            # Diagnostics only: a malformed URL (``.port`` can raise ValueError)
            # must never fail an otherwise healthy connection.
            logger.debug("Redis 已连接(URL 解析省略)")

    async def get_client(self) -> aioredis.Redis:
        """Return the shared Redis client, creating and pinging it on first use.

        The client is cached only after ``PING`` succeeds, so a failed attempt
        leaves the service unconnected and the next call retries instead of
        handing out a broken client.

        Raises:
            Exception: whatever ``from_url``/``ping`` raised, re-raised after logging.
        """
        if self._client is not None:
            return self._client

        client: Optional[aioredis.Redis] = None
        try:
            client = await aioredis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
                socket_timeout=redis_defaults.socket_timeout_seconds,
                socket_connect_timeout=redis_defaults.socket_connect_timeout_seconds,
                health_check_interval=redis_defaults.health_check_interval_seconds,
                retry_on_timeout=True,
            )
            await client.ping()
        except Exception as e:
            logger.error("Redis 连接失败: {}", e)
            if client is not None:
                # Best-effort cleanup; the connection error is what callers need.
                with contextlib.suppress(Exception):
                    await self._close_client(client)
            raise

        existing = self._client
        if existing is not None:
            # Another coroutine finished connecting while we were awaiting:
            # keep its client and discard ours instead of leaking a pool.
            with contextlib.suppress(Exception):
                await self._close_client(client)
            return existing

        self._client = client
        logger.info("Redis 连接成功")
        self._log_connection_target()
        return client

    async def close(self) -> None:
        """Close the shared client (if any); the service forgets it even if closing fails."""
        client, self._client = self._client, None
        if client is not None:
            await self._close_client(client)

    # ------------------------------------------------------------------
    # Conversation history (one Redis list of JSON objects per conversation)
    # ------------------------------------------------------------------

    def _conversation_key(self, conversation_id: str) -> str:
        """Return the Redis key that stores one conversation's messages."""
        return f"conversation:history:{conversation_id}"

    async def _key_type(self, client: aioredis.Redis, key: str) -> str:
        """Return the Redis TYPE of *key* as ``str`` (Redis reports ``"none"`` if missing)."""
        key_type = await client.type(key)
        if isinstance(key_type, bytes):
            return key_type.decode("utf-8")
        return str(key_type)

    async def _parse_history_items(self, raw_items: List[str]) -> List[Dict[str, Any]]:
        """Decode JSON list entries, skipping invalid JSON and non-object values."""
        history: List[Dict[str, Any]] = []
        for raw in raw_items:
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning("跳过无效对话历史条目")
                continue
            if isinstance(parsed, dict):
                history.append(parsed)
        return history

    @staticmethod
    def _decode_legacy_history(data: Optional[str]) -> List[Dict[str, Any]]:
        """Decode a legacy string-encoded history (a JSON array) into its dict entries.

        Returns an empty list for missing/empty data, invalid JSON, or JSON
        that is not an array.
        """
        if not data:
            return []
        try:
            legacy = json.loads(data)
        except json.JSONDecodeError:
            logger.warning("跳过无效对话历史条目")
            return []
        if not isinstance(legacy, list):
            return []
        return [x for x in legacy if isinstance(x, dict)]

    async def _write_history(
        self, client: aioredis.Redis, key: str, history: List[Dict[str, Any]]
    ) -> None:
        """Atomically replace *key* with a list holding *history* (deleted when empty)."""
        pipe = client.pipeline(transaction=True)
        pipe.delete(key)
        if history:
            pipe.rpush(key, *(json.dumps(item, ensure_ascii=False) for item in history))
            pipe.expire(key, self.session_ttl)
        await pipe.execute()

    async def _migrate_string_history_to_list(
        self, client: aioredis.Redis, key: str, history: List[Dict[str, Any]]
    ) -> None:
        """Rewrite a legacy string-typed history key as a list (or delete it if empty)."""
        await self._write_history(client, key, history)

    async def get_conversation_history(
        self, conversation_id: str
    ) -> List[Dict[str, Any]]:
        """Return the conversation's messages, oldest first; ``[]`` if missing or on error.

        A legacy string-encoded key is converted to a list as a side effect.
        """
        try:
            client = await self.get_client()
            key = self._conversation_key(conversation_id)
            key_type = await self._key_type(client, key)
            if key_type == "none":
                return []
            if key_type == "list":
                raw_items = await client.lrange(key, 0, -1)
                return await self._parse_history_items(list(raw_items))
            if key_type == "string":
                history = self._decode_legacy_history(await client.get(key))
                # Always migrate, even when nothing usable was decoded: leaving an
                # unusable string in place makes every RPUSH in add_message fail
                # with WRONGTYPE until the key expires.
                await self._migrate_string_history_to_list(client, key, history)
                return history
            logger.warning(
                "conversation history unexpected type={} key={}",
                key_type,
                key,
            )
            return []
        except Exception as e:
            logger.error("获取对话历史失败: {}", e)
            return []

    async def set_conversation_history(
        self, conversation_id: str, history: List[Dict[str, Any]]
    ) -> bool:
        """Overwrite the whole history (e.g. when backfilling from DB), applying session_ttl.

        Returns:
            True on success, False if any Redis error occurred (logged).
        """
        try:
            client = await self.get_client()
            await self._write_history(
                client, self._conversation_key(conversation_id), history
            )
            return True
        except Exception as e:
            logger.error("写入对话历史失败: {}", e)
            return False

    async def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        message_type: str = "text",
        voice_session_id: str | None = None,
        timestamp: str | int | None = None,
        audio_duration_seconds: int | None = None,
    ) -> bool:
        """Append one message to the conversation and refresh the session TTL.

        A falsy *timestamp* is replaced with the current UTC time in ISO format.
        ``durationSeconds`` is recorded only for positive durations on
        ``"audio"`` messages.

        Returns:
            True on success, False if any Redis error occurred (logged).
        """
        try:
            client = await self.get_client()
            key = self._conversation_key(conversation_id)
            item: Dict[str, Any] = {
                "role": role,
                "content": content,
                "messageType": message_type,
                "timestamp": timestamp or datetime.now(timezone.utc).isoformat(),
            }
            if voice_session_id:
                item["voiceSessionId"] = voice_session_id
            if (
                audio_duration_seconds is not None
                and audio_duration_seconds > 0
                and message_type == "audio"
            ):
                item["durationSeconds"] = int(audio_duration_seconds)
            pipe = client.pipeline(transaction=True)
            pipe.rpush(key, json.dumps(item, ensure_ascii=False))
            pipe.expire(key, self.session_ttl)
            await pipe.execute()
            return True
        except Exception as e:
            logger.error("添加消息失败: {}", e)
            return False

    async def append_tts_audio_url_to_last_ai_message(
        self, conversation_id: str, url: str
    ) -> bool:
        """Append *url* to ``ttsAudioUrls`` of the most recent ``role == "ai"`` message.

        *url* is the canonical URL returned by upload (not a presigned one);
        clients receive presigned URLs through outlets such as GET /messages.

        Returns:
            True if a message was updated; False for an empty *url*, when no AI
            message exists, or on a Redis error (logged).
        """
        if not url:
            return False
        try:
            client = await self.get_client()
            key = self._conversation_key(conversation_id)
            if await self._key_type(client, key) == "string":
                # Legacy string-encoded history: reading it converts the key to a
                # list so LRANGE/LSET below can address it.
                await self.get_conversation_history(conversation_id)

            # Scan the raw list so the index given to LSET is the real Redis
            # index. The parsed history skips invalid entries, so indexing into
            # it could overwrite the wrong element.
            raw_items = await client.lrange(key, 0, -1)
            target_index: Optional[int] = None
            target: Dict[str, Any] = {}
            for index in range(len(raw_items) - 1, -1, -1):
                try:
                    parsed = json.loads(raw_items[index])
                except json.JSONDecodeError:
                    continue
                if isinstance(parsed, dict) and parsed.get("role") == "ai":
                    target_index, target = index, parsed
                    break

            if target_index is None:
                logger.warning(
                    "append_tts_audio_url: no ai message in history conversation_id={}",
                    conversation_id,
                )
                return False

            existing = target.get("ttsAudioUrls")
            urls: List[str] = (
                [x for x in existing if isinstance(x, str)]
                if isinstance(existing, list)
                else []
            )
            urls.append(url)
            target["ttsAudioUrls"] = urls

            # NOTE(review): LRANGE -> LSET is not atomic; concurrent appends to
            # the same message could lose a URL. Consider WATCH/MULTI if uploads
            # for a single message can complete in parallel.
            pipe = client.pipeline(transaction=True)
            pipe.lset(key, target_index, json.dumps(target, ensure_ascii=False))
            pipe.expire(key, self.session_ttl)
            await pipe.execute()
            return True
        except Exception as e:
            logger.error("append_tts_audio_url 失败: {}", e)
            return False

    async def clear_conversation_history(self, conversation_id: str) -> bool:
        """Delete the conversation's history key; False on Redis error (logged)."""
        try:
            client = await self.get_client()
            await client.delete(self._conversation_key(conversation_id))
            return True
        except Exception as e:
            logger.error("清除对话历史失败: {}", e)
            return False

    async def delete_keys_matching_pattern(
        self, pattern: str, batch_size: int = 200
    ) -> int:
        """Delete keys matching *pattern* via SCAN (avoids blocking ``KEYS *``).

        Args:
            pattern: glob-style pattern passed to SCAN MATCH.
            batch_size: keys per DEL call, also used as the SCAN COUNT hint;
                values below 1 are treated as 1.

        Returns:
            Number of keys deleted; 0 on Redis error (logged).
        """
        batch_size = max(1, batch_size)
        try:
            client = await self.get_client()
            batch: List[str] = []
            deleted = 0
            async for key in client.scan_iter(match=pattern, count=batch_size):
                batch.append(key)
                if len(batch) >= batch_size:
                    deleted += int(await client.delete(*batch))
                    batch.clear()
            if batch:
                deleted += int(await client.delete(*batch))
            return deleted
        except Exception as e:
            logger.error("按 pattern 删除 Redis key 失败: {}", e)
            return 0

    async def extend_session_ttl(self, conversation_id: str) -> bool:
        """Reset the conversation key's TTL to session_ttl.

        Returns True whenever the EXPIRE call itself succeeds, even if the key
        does not exist; False only on Redis error (logged).
        """
        try:
            client = await self.get_client()
            await client.expire(self._conversation_key(conversation_id), self.session_ttl)
            return True
        except Exception as e:
            logger.error("延长会话TTL失败: {}", e)
            return False

    # ------------------------------------------------------------------
    # Generic cache
    # ------------------------------------------------------------------

    async def set_cache(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store *value* under *key* with SETEX; *ttl* is mandatory and must be positive.

        Strings are stored as-is; any other value is JSON-encoded.

        Returns:
            True on success; False for a missing/non-positive TTL or a Redis error.
        """
        if ttl is None or ttl <= 0:
            logger.error("设置缓存失败: TTL 必须为正整数 key={}", key)
            return False
        try:
            client = await self.get_client()
            data = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
            await client.setex(key, ttl, data)
            return True
        except Exception as e:
            logger.error("设置缓存失败: {}", e)
            return False

    async def get_cache(self, key: str) -> Optional[Any]:
        """Return the cached value, JSON-decoded when possible, else the raw string.

        Returns None when the key is missing or empty, or on Redis error (logged).
        """
        try:
            client = await self.get_client()
            data = await client.get(key)
            if not data:
                return None
            try:
                return json.loads(data)
            except json.JSONDecodeError:
                return data
        except Exception as e:
            logger.error("获取缓存失败: {}", e)
            return None

    async def delete_cache(self, key: str) -> bool:
        """Delete *key*; False on Redis error (logged)."""
        try:
            client = await self.get_client()
            await client.delete(key)
            return True
        except Exception as e:
            logger.error("删除缓存失败: {}", e)
            return False

    def is_available(self) -> bool:
        """True once a client has connected (PING passed) and not been closed; no live probe."""
        return self._client is not None


# Global singleton, used by the main lifecycle and by features either directly
# or through get_redis_service (defined elsewhere).
redis_service = RedisService()