1149 lines
52 KiB
Python
1149 lines
52 KiB
Python
"""
|
||
WebSocket routes: real-time conversation messaging.

Supports asynchronous Agent calls and Redis-backed session storage.
|
||
"""
|
||
import asyncio
|
||
import logging
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from enum import Enum
|
||
from typing import Dict, List, Optional, Set, Tuple
|
||
|
||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
|
||
from starlette.websockets import WebSocketState
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from agents import ConversationAgent, MemoryAgent
|
||
from agents.memoir_processor import BackgroundTaskRunner
|
||
from database import get_async_db
|
||
from database.models import Conversation, Segment
|
||
from database.models import User as UserModel
|
||
from services.auth_service import verify_token
|
||
from services.memoir_state_service import get_or_create_state
|
||
from services import asr_service, redis_service
|
||
from agents.prompts.profile_prompts import format_user_profile_context
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Voice-session id substituted when a message or stored audio_url carries none
# (see _normalize_voice_session_id and _extract_segment_scope).
LEGACY_VOICE_SESSION_ID = "legacy"
|
||
|
||
|
||
class MessageType(str, Enum):
    """WebSocket message types."""

    CONNECT = "connect"
    # Client started recording; the server uses it to send the one-off
    # "I'm listening" feedback about 5 seconds later.
    RECORDING_STARTED = "recording_started"
    AUDIO_CHUNK = "audio_chunk"
    AUDIO_SEGMENT = "audio_segment"  # one segment of a long voice message uploaded piece by piece
    AUDIO_MESSAGE = "audio_message"  # a complete audio message (like a WeChat voice note)
    # Transcribe only: nothing is persisted and no Agent is triggered; only the
    # transcript is returned.
    TRANSCRIBE_ONLY = "transcribe_only"
    TEXT = "text"  # text message
    TRANSCRIPT = "transcript"  # speech-to-text result
    AGENT_RESPONSE = "agent_response"
    TTS_AUDIO = "tts_audio"
    END_CONVERSATION = "end_conversation"
    MEMOIR_UPDATE = "memoir_update"
    ERROR = "error"
|
||
|
||
|
||
# 连接管理
|
||
class ConnectionManager:
    """Registry of live WebSocket connections and per-voice-session segment state.

    Connections are keyed by ``conversation_id``; segment states are keyed by
    ``(conversation_id, voice_session_id)``. The manager also holds the agent
    instances and the background task runner shared by the handlers.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.segment_states: Dict[Tuple[str, str], "SegmentStreamState"] = {}
        # ConversationAgent is stateless now (sessions are stored in Redis),
        # so a single instance can be reused.
        self.conversation_agent = ConversationAgent()
        self.memory_agent = MemoryAgent()
        self.background_runner = BackgroundTaskRunner()

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept the socket and register it under ``conversation_id``."""
        await websocket.accept()
        self.active_connections[conversation_id] = websocket

    async def disconnect(self, conversation_id: str):
        """Forget the connection and drop its segment states that have no running tasks.

        States that still have active tasks are kept; the task done-callback
        installed by ``register_segment_task`` removes them later.
        """
        self.active_connections.pop(conversation_id, None)
        stale_keys = [
            key
            for key, state in self.segment_states.items()
            if key[0] == conversation_id and not state.active_tasks
        ]
        for key in stale_keys:
            self.segment_states.pop(key, None)
        # The conversation memory in Redis is deliberately left in place
        # (clearing it here is optional; keeping it allows a later resume).

    def get_or_create_segment_state(
        self,
        conversation_id: str,
        voice_session_id: str,
    ) -> "SegmentStreamState":
        """Return the state for this voice session, creating it on first use."""
        state_key = (conversation_id, voice_session_id)
        if state_key not in self.segment_states:
            self.segment_states[state_key] = SegmentStreamState()
        return self.segment_states[state_key]

    def register_segment_task(
        self,
        conversation_id: str,
        voice_session_id: str,
        task: asyncio.Task,
    ) -> None:
        """Track ``task`` on its voice-session state and clean up when it finishes."""
        state_key = (conversation_id, voice_session_id)
        state = self.get_or_create_segment_state(conversation_id, voice_session_id)
        state.active_tasks.add(task)

        def _cleanup(done_task: asyncio.Task) -> None:
            state.active_tasks.discard(done_task)
            # Last task of a session whose connection is already gone: the
            # state can no longer be reached, so drop it.
            if not state.active_tasks and conversation_id not in self.active_connections:
                self.segment_states.pop(state_key, None)
            if done_task.cancelled():
                return
            exc = done_task.exception()
            if exc:
                # A done-callback runs outside any ``except`` block, so
                # ``exc_info=True`` would log no traceback; pass the exception
                # object itself instead.
                logger.error(
                    "分段处理任务异常 "
                    f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
                    exc_info=exc,
                )

        task.add_done_callback(_cleanup)

    async def send_message(self, conversation_id: str, message: dict):
        """Send ``message`` as JSON; silently do nothing when no socket is registered.

        A failed send is logged and the failing socket is unregistered.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            return
        try:
            # Raises when the connection has already been closed.
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}")
            # Only remove the socket that actually failed: a reconnect may have
            # registered a newer socket while the send was awaited.
            if self.active_connections.get(conversation_id) is websocket:
                del self.active_connections[conversation_id]

    async def receive_message(self, conversation_id: str) -> dict:
        """Receive one JSON message from the registered socket.

        Raises:
            HTTPException: 404 when no socket is registered for the conversation.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is not None:
            return await websocket.receive_json()
        raise HTTPException(status_code=404, detail="Connection not found")
|
||
|
||
|
||
# Module-level ConnectionManager instance used by websocket_endpoint.
manager = ConnectionManager()
|
||
|
||
|
||
@dataclass
class SegmentStreamState:
    """Per-voice-session segment processing state (parallel ASR + in-order aggregation)."""

    # Taken by the handlers around reads/writes of the fields below.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Segment indices accepted by the endpoint whose task has not finished yet.
    pending_indices: Set[int] = field(default_factory=set)
    # Segment indices that were persisted (or found already persisted).
    processed_indices: Set[int] = field(default_factory=set)
    # index -> (transcript, Segment) waiting for earlier indices to arrive.
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest index already handed on in order; -1 means none yet.
    consumed_index: int = -1
    # Running segment-processing tasks (see ConnectionManager.register_segment_task).
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    # "I'm listening" is sent only once, about 5s after recording starts; the
    # pending send is cancelled if the user finishes recording earlier.
    listening_feedback_sent: bool = False
    listening_feedback_task: Optional[asyncio.Task] = None
|
||
|
||
|
||
def _utc_now() -> datetime:
|
||
return datetime.now(timezone.utc)
|
||
|
||
|
||
def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime:
    """Stamp ``conversation.last_message_at`` and return the timestamp that was used.

    Args:
        conversation: Conversation whose ``last_message_at`` is updated in place.
        at: Timestamp to record; the current UTC time is used when omitted.
    """
    stamp = at if at else _utc_now()
    conversation.last_message_at = stamp
    return stamp
|
||
|
||
|
||
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
|
||
if voice_session_id:
|
||
return str(voice_session_id)
|
||
return LEGACY_VOICE_SESSION_ID
|
||
|
||
|
||
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
|
||
if not client_segment_id:
|
||
return None
|
||
session_id, separator, _ = client_segment_id.rpartition("-")
|
||
if separator and session_id:
|
||
return session_id
|
||
return None
|
||
|
||
|
||
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
|
||
"""构建分段语音的幂等标识(conversation_id + voice_session_id + segment_index)。"""
|
||
return f"audio-segment:{voice_session_id}:{segment_index}"
|
||
|
||
|
||
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
|
||
"""从 audio_url 中解析 voice_session_id 与 segment_index。兼容旧格式 audio-segment:{index}。"""
|
||
prefix = "audio-segment:"
|
||
if not audio_url or not audio_url.startswith(prefix):
|
||
return None
|
||
payload = audio_url[len(prefix):]
|
||
voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
|
||
try:
|
||
if separator:
|
||
return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw))
|
||
return (LEGACY_VOICE_SESSION_ID, int(payload))
|
||
except ValueError:
|
||
return None
|
||
|
||
|
||
def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in a segment ``audio_url``, or None."""
    parsed = _extract_segment_scope(audio_url)
    return parsed[0] if parsed else None
|
||
|
||
|
||
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
|
||
if not transcript_text:
|
||
return True
|
||
return transcript_text.startswith("转写失败")
|
||
|
||
|
||
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Look up an already persisted segment by conversation, voice session and index.

    The rows are filtered a second time in Python because the test stub's
    ``execute()`` does not actually apply the WHERE clause; this keeps the
    function correct against both a real database and the stub.

    Returns:
        The first matching row (the query orders newest first), or None.
    """
    wanted_url = _build_segment_audio_url(voice_session_id, segment_index)
    query = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == wanted_url,
        )
        .order_by(Segment.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id and row.audio_url == wanted_url
        ),
        None,
    )
|
||
|
||
|
||
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest index N such that segments 0..N of this voice session are all persisted.

    Used to restore ordering state after a reconnect. Returns -1 when segment
    0 has not been persisted.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()

    seen: Set[int] = set()
    for row in rows:
        # Re-check the conversation in Python as well as in the query.
        if row.conversation_id != conversation_id:
            continue
        scope = _extract_segment_scope(row.audio_url)
        if scope and scope[0] == voice_session_id:
            seen.add(scope[1])

    # Walk upwards from 0 until the first gap.
    highest = -1
    while highest + 1 in seen:
        highest += 1
    return highest
|
||
|
||
|
||
# Seconds _delayed_listening_feedback sleeps before sending the one-off feedback.
LISTENING_FEEDBACK_DELAY_SEC = 5.0
# Text of that feedback message (sent by _send_segment_transition_feedback).
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
|
||
|
||
|
||
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
    manager: ConnectionManager,
) -> None:
    """Send the "I'm listening" transition feedback once (called by the delayed task)."""
    feedback = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": {
            "text": LISTENING_FEEDBACK_TEXT,
            # Marks the message as a transition message rather than a real answer.
            "transition": True,
            "segment_index": segment_index,
        },
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, feedback)
|
||
|
||
|
||
async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
    manager: ConnectionManager,
) -> None:
    """Send the "I'm listening" feedback once, LISTENING_FEEDBACK_DELAY_SEC after recording starts.

    The feedback is sent at most once per voice session. The endpoint cancels
    this task when the user finishes recording before the delay elapses.

    Args:
        conversation_id: Conversation whose socket receives the feedback.
        voice_session_id: Voice session the feedback belongs to.
        manager: Connection manager holding the sockets and segment states.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    # If the client disconnected during the delay, nothing can be delivered.
    # Return before get_or_create_segment_state(): it would re-create a state
    # that disconnect() already removed and leave an orphan entry behind.
    if conversation_id not in manager.active_connections:
        return
    state = manager.get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        if state.listening_feedback_sent:
            return
        state.listening_feedback_sent = True
        state.listening_feedback_task = None
    await _send_segment_transition_feedback(conversation_id, 0, manager)
|
||
|
||
|
||
async def _process_audio_segment_async(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
    manager: ConnectionManager,
) -> None:
    """Process one voice segment: parallel ASR, idempotent persistence, in-order Agent trigger.

    Steps, as implemented below:
      1. Load the conversation and user with a DB session of its own.
      2. On the first segment seen for this state, restore ``consumed_index``
         from the segments already persisted (reconnect recovery).
      3. Transcribe and send the TRANSCRIPT to the client; on a failed
         transcription send an ERROR and stop.
      4. Skip segments that are already persisted; otherwise persist the
         segment and queue it on the background runner.
      5. Buffer the transcript and pass every segment that is now contiguous
         with ``consumed_index`` to ``process_user_message`` in index order.

    Any exception is logged and reported to the client as an ERROR message.
    The segment index is always removed from ``pending_indices`` at the end.

    Args:
        conversation_id: Conversation the segment belongs to.
        user_id: Id of the authenticated user.
        voice_session_id: Normalized voice session id.
        segment_index: Zero-based position of the segment within the session.
        audio_base64: Audio payload handed to the ASR service.
        audio_duration: Duration reported by the client; echoed back in TRANSCRIPT.
        is_last: Whether the client marked this as the final segment.
        manager: Connection manager holding sockets, states and agents.
    """
    state = manager.get_or_create_segment_state(conversation_id, voice_session_id)

    try:
        # Every segment task uses its own DB session, so it never shares an
        # AsyncSession with the main loop (which would cause concurrency conflicts).
        async for db in get_async_db():
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(UserModel, user_id)
            if not conversation:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "对话不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return
            if not user:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "用户不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            # The state is "fresh" when nothing has been consumed, processed
            # or buffered yet.
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )

            if should_prime_state:
                # The DB query runs outside the lock; max() below keeps a
                # concurrent update from being overwritten with a lower value.
                persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
                    db=db,
                    conversation_id=conversation_id,
                    voice_session_id=voice_session_id,
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        state.consumed_index = max(state.consumed_index, persisted_contiguous_index)

            transcript_text = await asr_service.transcribe(audio_base64)
            await manager.send_message(conversation_id, {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {
                    "text": transcript_text or "",
                    "audio_duration": audio_duration,
                    "voice_session_id": voice_session_id,
                    "segment_index": segment_index,
                    "is_last": is_last,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })

            if _is_transcribe_failure(transcript_text):
                # The index is not marked as processed, so the client may
                # resend this segment.
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {
                        "message": f"分段 {segment_index} 转写失败,请重试该片段",
                        "segment_index": segment_index,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                # Already persisted, so treat it as a retransmission: do not
                # persist again and do not trigger the Agent again.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.info(
                    "分段已存在,按幂等处理跳过: "
                    f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
                )
                return
            else:
                segment = Segment(
                    id=str(uuid.uuid4()),
                    conversation_id=conversation_id,
                    transcript_text=transcript_text or "",
                    audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                    processed=False,
                )
                db.add(segment)
                user_message_timestamp = _mark_conversation_active(conversation)
                await db.commit()
                await db.refresh(segment)
                await manager.background_runner.queue_message(conversation.user_id, segment.id)

            ready_segments: List[Tuple[int, str, Segment]] = []
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (transcript_text or "", segment)

                # Drain every buffered segment that directly follows the last
                # consumed index.
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1

            # The Agent is only triggered for a contiguous prefix of segments,
            # which keeps the incremental context in the right order.
            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    manager=manager,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at or user_message_timestamp,
                )

            # Only the first session yielded by get_async_db() is used.
            break

    except Exception as e:
        logger.error(
            f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
            exc_info=True,
        )
        await manager.send_message(conversation_id, {
            "type": MessageType.ERROR,
            "data": {
                "message": f"分段处理失败: {str(e)}",
                "segment_index": segment_index,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
    finally:
        # Whatever happened, the segment is no longer in flight.
        async with state.lock:
            state.pending_indices.discard(segment_index)
|
||
|
||
|
||
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str
):
    """
    WebSocket endpoint: handles one real-time conversation.

    Flow: authenticate the ``token`` query parameter, verify that the user
    exists and owns the conversation, register the socket, send the opening
    messages when the conversation has no history yet, then process incoming
    messages until the client disconnects or ends the conversation.

    Ownership is verified BEFORE the socket is registered with the manager, so
    a connection that is going to be rejected can never replace (and later
    remove) the owner's registered socket.

    Args:
        websocket: The WebSocket connection.
        conversation_id: Conversation ID.
    """
    # The token is taken from the query parameters.
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
        return

    # Validate the JWT.
    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
        return

    async for db in get_async_db():
        # The user must exist.
        user = await db.get(UserModel, user_id)
        if not user:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
            return

        # Users may only access their own conversations. Reject here, without
        # touching manager.active_connections.
        conversation = await db.get(Conversation, conversation_id)
        if conversation and conversation.user_id != user_id:
            try:
                await websocket.accept()
                await websocket.send_json({
                    "type": MessageType.ERROR,
                    "data": {"message": "无权访问此对话"},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
                await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
            except Exception as reject_error:
                # Best effort: the client may already be gone.
                logger.warning(
                    f"Failed to reject unauthorized connection: "
                    f"conversation_id={conversation_id}, error={reject_error}"
                )
            return

        async def send_error(data: dict) -> None:
            """Send an ERROR message carrying ``data`` to this conversation."""
            await manager.send_message(conversation_id, {
                "type": MessageType.ERROR,
                "data": data,
                "timestamp": datetime.now(timezone.utc).isoformat()
            })

        async def send_agent_texts(texts) -> None:
            """Send texts as AGENT_RESPONSE messages, pausing 0.5s between them."""
            total = len(texts)
            for i, text in enumerate(texts):
                await manager.send_message(conversation_id, {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": {"text": text, "index": i, "total": total},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
                if i < total - 1:
                    await asyncio.sleep(0.5)

        async def quota_allows_message() -> bool:
            """Check the conversation-turn quota; send QUOTA_EXCEEDED when it is used up."""
            # Imported lazily, exactly where the check is needed.
            from routers.quota import get_segment_count, check_can_send_message
            seg_count = await get_segment_count(user_id, db)
            can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
            if not can_send:
                await send_error({"message": quota_msg, "code": "QUOTA_EXCEEDED"})
            return bool(can_send)

        async def report_failure(exc: Exception) -> None:
            """Tell the client about ``exc``, but only while the connection is registered."""
            if conversation_id in manager.active_connections:
                try:
                    await send_error({"message": str(exc)})
                except Exception as send_failure:
                    logger.warning(f"发送错误消息失败: {send_failure}")

        await manager.connect(websocket, conversation_id)

        try:
            # Connection confirmation.
            await manager.send_message(conversation_id, {
                "type": MessageType.CONNECT,
                "conversation_id": conversation_id,
                "data": {"status": "connected"},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })

            if not conversation:
                # The conversation does not exist yet: create it.
                conversation = Conversation(
                    id=conversation_id,
                    user_id=user_id,
                    started_at=datetime.now(timezone.utc),
                    status="active"
                )
                db.add(conversation)
                await db.commit()

            # If Redis already has history (the user has been in this
            # conversation before), skip the opening to avoid repeating it.
            history = await redis_service.get_conversation_history(conversation_id)
            if not history:
                # Empty conversation: send an opening (profile collection or interview).
                missing_profile = _get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await manager.conversation_agent.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                        )
                        await send_agent_texts(greetings)
                    except Exception as e:
                        logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
                else:
                    # Profile complete: the AI opens with a question.
                    try:
                        memoir_state = await get_or_create_state(user_id, db)
                        user_profile_context = format_user_profile_context(
                            birth_year=user.birth_year,
                            birth_place=user.birth_place,
                            grew_up_place=user.grew_up_place,
                            occupation=user.occupation,
                        )
                        opening_messages = await manager.conversation_agent.generate_opening_message(
                            conversation_id=conversation_id,
                            memoir_state=memoir_state,
                            user_profile_context=user_profile_context,
                        )
                        await send_agent_texts(opening_messages)
                    except Exception as e:
                        logger.error(f"发送空对话开场白失败: {e}", exc_info=True)

            # Main loop: one iteration per incoming message.
            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}")
                        break
                    message = await websocket.receive_json()
                    msg_type = message.get("type")

                    if msg_type == MessageType.TEXT:
                        # Text message.
                        text_message = message.get("data", {}).get("text", "")

                        if text_message:
                            if not await quota_allows_message():
                                continue

                            # Persist the segment.
                            segment = Segment(
                                id=str(uuid.uuid4()),
                                conversation_id=conversation_id,
                                transcript_text=text_message,
                                processed=False
                            )
                            db.add(segment)
                            user_message_timestamp = _mark_conversation_active(conversation)
                            await db.commit()
                            await db.refresh(segment)
                            await manager.background_runner.queue_message(conversation.user_id, segment.id)

                            # Let the Agent answer.
                            await process_user_message(
                                conversation_id=conversation_id,
                                user_message=text_message,
                                conversation=conversation,
                                segment=segment,
                                db=db,
                                manager=manager,
                                user=user,
                                user_message_timestamp=segment.created_at or user_message_timestamp,
                            )

                    elif msg_type == MessageType.RECORDING_STARTED:
                        # Recording started: arm the timer that sends the
                        # one-off "I'm listening" feedback.
                        data = message.get("data", {})
                        voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
                        segment_state = manager.get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            pending_feedback = segment_state.listening_feedback_task
                            if pending_feedback is not None and not pending_feedback.done():
                                continue  # a timer is already pending for this session
                            if segment_state.listening_feedback_sent:
                                continue
                            segment_state.listening_feedback_task = asyncio.create_task(
                                _delayed_listening_feedback(
                                    conversation_id=conversation_id,
                                    voice_session_id=voice_session_id,
                                    manager=manager,
                                )
                            )

                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        # One segment of a long voice message.
                        data = message.get("data", {})
                        audio_base64 = data.get("audio_base64", "")
                        segment_index_raw = data.get("segment_index")
                        voice_session_id = _normalize_voice_session_id(
                            data.get("voice_session_id")
                            or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
                        )
                        is_last = bool(data.get("is_last", False))
                        audio_duration = int(data.get("duration", 0) or 0)

                        if not audio_base64:
                            await send_error({"message": "缺少 audio_base64"})
                            continue

                        if segment_index_raw is None:
                            await send_error({"message": "缺少 segment_index"})
                            continue

                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError):
                            await send_error({"message": "segment_index 必须为整数"})
                            continue

                        if segment_index < 0:
                            await send_error({"message": "segment_index 不能为负数"})
                            continue

                        # Segments count towards the conversation-turn quota too.
                        if not await quota_allows_message():
                            continue

                        segment_state = manager.get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        should_process = False
                        async with segment_state.lock:
                            already_seen = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not already_seen:
                                segment_state.pending_indices.add(segment_index)
                                should_process = True

                        if not should_process:
                            logger.info(
                                "收到重复分段,跳过处理: "
                                f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
                            )
                            continue

                        # Final segment: cancel a still-pending "I'm listening"
                        # so it is not sent after the user has finished.
                        if is_last:
                            async with segment_state.lock:
                                pending_feedback = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                                if pending_feedback is not None and not pending_feedback.done():
                                    pending_feedback.cancel()

                        task = asyncio.create_task(
                            _process_audio_segment_async(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                                manager=manager,
                            )
                        )
                        manager.register_segment_task(conversation_id, voice_session_id, task)

                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        # A complete audio message (like a WeChat voice note).
                        data = message.get("data", {})
                        audio_base64 = data.get("audio_base64", "")
                        audio_duration = data.get("duration", 0)

                        if audio_base64:
                            if not await quota_allows_message():
                                continue

                            logger.info(f"收到音频消息,时长: {audio_duration}s")

                            try:
                                # 1. ASR transcription.
                                transcript_text = await asr_service.transcribe(audio_base64)
                                logger.info(f"ASR 转写结果: {transcript_text}")

                                # 2. Send the transcript to the client.
                                await manager.send_message(conversation_id, {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {
                                        "text": transcript_text,
                                        "audio_duration": audio_duration
                                    },
                                    "timestamp": datetime.now(timezone.utc).isoformat()
                                })

                                # 3. Persist the segment (transcript plus audio marker).
                                segment = Segment(
                                    id=str(uuid.uuid4()),
                                    conversation_id=conversation_id,
                                    transcript_text=transcript_text,
                                    audio_url=f"audio:{audio_duration}s",  # simplified marker for an audio message
                                    processed=False
                                )
                                db.add(segment)
                                user_message_timestamp = _mark_conversation_active(conversation)
                                await db.commit()
                                await db.refresh(segment)
                                await manager.background_runner.queue_message(conversation.user_id, segment.id)

                                # 4. Let the Agent answer, based on the transcript.
                                if not _is_transcribe_failure(transcript_text):
                                    await process_user_message(
                                        conversation_id=conversation_id,
                                        user_message=transcript_text,
                                        conversation=conversation,
                                        segment=segment,
                                        db=db,
                                        manager=manager,
                                        user=user,
                                        user_message_timestamp=segment.created_at or user_message_timestamp,
                                    )
                                else:
                                    # Transcription failed.
                                    await send_error({"message": "语音转写失败,请重试或使用文字输入"})

                            except Exception as e:
                                logger.error(f"处理音频消息失败: {e}", exc_info=True)
                                await send_error({"message": f"处理音频消息失败: {str(e)}"})

                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        # Transcribe only: nothing is persisted, no Agent is
                        # triggered, the transcript is just returned.
                        data = message.get("data", {})
                        audio_base64 = data.get("audio_base64", "")
                        if not audio_base64:
                            await send_error({"message": "缺少 audio_base64"})
                            continue
                        try:
                            transcript_text = await asr_service.transcribe(audio_base64)
                            await manager.send_message(conversation_id, {
                                "type": MessageType.TRANSCRIPT,
                                "conversation_id": conversation_id,
                                "data": {"text": transcript_text or ""},
                                "timestamp": datetime.now(timezone.utc).isoformat()
                            })
                        except Exception as e:
                            logger.error(f"仅转写失败: {e}", exc_info=True)
                            await send_error({"message": f"转写失败: {str(e)}"})

                    elif msg_type == MessageType.END_CONVERSATION:
                        # End the conversation.
                        conversation.status = "ended"
                        conversation.ended_at = datetime.now(timezone.utc)
                        await db.commit()

                        # Trigger the organizing Agent.
                        await process_conversation_segments(conversation_id, db)

                        await manager.send_message(conversation_id, {
                            "type": MessageType.END_CONVERSATION,
                            "conversation_id": conversation_id,
                            "data": {"status": "ended"},
                            "timestamp": datetime.now(timezone.utc).isoformat()
                        })
                        break

                except RuntimeError as e:
                    # Starlette raises RuntimeError when the socket is
                    # disconnected or not connected (closed around accept).
                    error_msg = str(e)
                    lowered = error_msg.lower()
                    if (
                        "disconnect" in lowered
                        or "Cannot call \"receive\"" in error_msg
                        or ("accept" in lowered and "not connected" in lowered)
                    ):
                        logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}")
                    else:
                        logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
                        await report_failure(e)
                    break
                except WebSocketDisconnect:
                    logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
                    break
                except Exception as e:
                    logger.error(f"处理消息时发生错误: {e}", exc_info=True)
                    await report_failure(e)
                    break

        except WebSocketDisconnect:
            logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
        except Exception as e:
            logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
        finally:
            # Always clean up, but never drop a newer socket that a reconnect
            # registered under the same conversation while this handler ran.
            registered = manager.active_connections.get(conversation_id)
            if registered is None or registered is websocket:
                await manager.disconnect(conversation_id)
|
||
|
||
|
||
def _get_missing_profile_fields(user: UserModel) -> list:
    """Return the profile fields the user has not filled in yet.

    Delegates to ``get_missing_profile_fields`` with the user's four profile
    attributes.
    """
    from agents.prompts.profile_prompts import get_missing_profile_fields

    profile_values = {
        "birth_year": user.birth_year,
        "birth_place": user.birth_place,
        "grew_up_place": user.grew_up_place,
        "occupation": user.occupation,
    }
    return get_missing_profile_fields(**profile_values)
|
||
|
||
|
||
def _get_filled_profile_fields(user: UserModel) -> dict:
    """Return the profile fields the user has already filled in.

    Keys are the field names (``birth_year``, ``birth_place``,
    ``grew_up_place``, ``occupation``); only truthy values are included and
    ``birth_year`` is converted to a string.
    """
    filled = {}
    if user.birth_year:
        filled["birth_year"] = str(user.birth_year)
    for field_name in ("birth_place", "grew_up_place", "occupation"):
        value = getattr(user, field_name)
        if value:
            filled[field_name] = value
    return filled
|
||
|
||
|
||
async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession):
    """Copy newly extracted profile values onto ``user`` and persist them.

    A value is applied only when its key is present in ``extracted`` and the
    matching attribute on ``user`` is currently falsy, so data that is
    already stored is never overwritten. The session is committed and
    ``user`` refreshed only if at least one attribute was assigned.
    """
    profile_fields = ("birth_year", "birth_place", "grew_up_place", "occupation")
    updated = False
    for field_name in profile_fields:
        if field_name not in extracted or getattr(user, field_name):
            continue
        setattr(user, field_name, extracted[field_name])
        updated = True

    if not updated:
        return
    await db.commit()
    await db.refresh(user)
|
||
|
||
|
||
async def _deliver_agent_responses(
    manager: ConnectionManager,
    conversation_id: str,
    responses: List[str],
    pause_seconds: float = 0.5,
) -> None:
    """Send each response as its own AGENT_RESPONSE message, pausing between them."""
    total = len(responses)
    for index, response_text in enumerate(responses):
        await manager.send_message(conversation_id, {
            "type": MessageType.AGENT_RESPONSE,
            "conversation_id": conversation_id,
            "data": {"text": response_text, "index": index, "total": total},
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
        # No pause is needed after the final message.
        if index < total - 1:
            await asyncio.sleep(pause_seconds)


async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager,
    user: Optional[UserModel] = None,
    user_message_timestamp: Optional[datetime] = None,
) -> None:
    """Generate the agent's reply to a user message and push it to the client.

    Two modes are supported:

    * Profile collection - used when ``user`` is given and
      ``_get_missing_profile_fields`` reports missing fields. Profile data is
      extracted from the message, stored, and a follow-up reply is generated.
    * Interview - used otherwise, or as a fallback when producing the
      profile-collection reply fails.

    In both modes the replies are joined with blank lines into
    ``segment.agent_response``, the conversation is marked active, the
    session is committed, and the replies are then sent one by one.

    Args:
        conversation_id: Conversation the message belongs to; also the key
            used for sending messages through ``manager``.
        user_message: Text of the user's message.
        conversation: Conversation row; its stage is synced with the memoir
            state in interview mode.
        segment: Segment row that receives the agent response. A truthy
            ``audio_url`` marks the message as coming from voice.
        db: Async database session.
        manager: Connection manager providing ``conversation_agent`` and
            ``send_message``.
        user: Optional user row; enables profile collection and profile
            context.
        user_message_timestamp: Optional timestamp forwarded to the agent.
    """
    agent = manager.conversation_agent

    # --- Profile collection mode ---
    if user:
        missing = _get_missing_profile_fields(user)
        if missing:
            profile_responses = None
            try:
                extracted = await agent.extract_profile_from_message(
                    user_message, missing, conversation_id=conversation_id
                )
                if extracted:
                    await _apply_extracted_profile(user, extracted, db)

                generated = await agent.generate_profile_followup(
                    conversation_id=conversation_id,
                    user_message=user_message,
                    missing_fields=_get_missing_profile_fields(user),
                    filled_fields=_get_filled_profile_fields(user),
                    nickname=user.nickname or "",
                    is_from_voice=bool(segment.audio_url),
                    voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
                    user_message_timestamp=user_message_timestamp,
                )

                segment.agent_response = "\n\n".join(generated)
                _mark_conversation_active(conversation)
                await db.commit()
                profile_responses = generated
            except Exception as e:
                # Producing the profile reply failed: fall back to interview mode.
                logger.error(f"资料收集处理失败: {e}", exc_info=True)

            if profile_responses is not None:
                # The reply is already persisted. A delivery failure must not
                # fall through to interview mode, which would call the agent a
                # second time and overwrite the stored response.
                try:
                    await _deliver_agent_responses(manager, conversation_id, profile_responses)
                except Exception as e:
                    logger.error(f"资料收集回应发送失败: {e}", exc_info=True)
                return

    # --- Interview mode ---
    state = await get_or_create_state(conversation.user_id, db)

    # Keep the conversation's stage in sync with the user's memoir state.
    if conversation.conversation_stage != state.current_stage:
        conversation.conversation_stage = state.current_stage
        await db.commit()

    # Build the user-profile context passed to the agent.
    user_profile_context = ""
    if user:
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )

    try:
        responses = await agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            is_from_voice=bool(segment.audio_url),
            voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
            user_message_timestamp=user_message_timestamp,
        )

        segment.agent_response = "\n\n".join(responses)
        _mark_conversation_active(conversation)
        await db.commit()

        await _deliver_agent_responses(manager, conversation_id, responses)

    except Exception as e:
        logger.error(f"处理用户消息失败: {e}", exc_info=True)
        # Only report the error to the client while the connection is registered.
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": f"生成回应失败: {str(e)}"},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
            except Exception as send_error:
                logger.warning(f"发送错误消息失败: {send_error}")
|
||
|
||
|
||
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """Submit a conversation's unprocessed segments when the conversation ends.

    Most processing is expected to have happened incrementally through Celery
    tasks already; this function hands any segments still flagged as
    unprocessed straight to Celery (bypassing the debounce) and then flushes
    the background runner's pending work for the user.

    Nothing is submitted when there are no unprocessed segments or when the
    quota check refuses the user; the pending work is still flushed in both
    cases. If the conversation does not exist the function returns without
    doing anything.

    Args:
        conversation_id: Conversation ID.
        db: Database session.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    # Fetch every segment that has not been processed yet.
    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        # SQLAlchemy column comparison: `== False` builds the SQL predicate,
        # so it must not be rewritten as `not ...` / `is False`.
        Segment.processed == False  # noqa: E712
    )
    result = await db.execute(stmt)
    segment_ids = [seg.id for seg in result.scalars().all()]

    can_submit = bool(segment_ids)
    if can_submit:
        # The free tier only allows one chapter to be organized; validate
        # before submitting.
        from routers.quota import get_chapter_count, check_can_submit_organize

        user = await db.get(UserModel, conversation.user_id)
        if user:
            chapter_count = await get_chapter_count(user.id, db)
            can_submit, _ = check_can_submit_organize(user.subscription_type, chapter_count)
            if not can_submit:
                logger.info(
                    f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}"
                )

    if can_submit:
        # Hand the unprocessed segments directly to Celery (no debounce).
        try:
            from tasks.memoir_tasks import process_memoir_segments

            process_memoir_segments.delay(conversation.user_id, segment_ids)
            logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
        except Exception as e:
            # Best effort: keep the traceback, but still flush below.
            logger.error(f"提交 Celery 任务失败: {e}", exc_info=True)

    # Always flush whatever is still pending for this user.
    await manager.background_runner.flush_pending(conversation.user_id)
|