Files
life-echo/api/routers/websocket.py

1149 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WebSocket 路由:实时对话通信
支持异步 Agent 调用和 Redis 会话存储
"""
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
from starlette.websockets import WebSocketState
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from agents import ConversationAgent, MemoryAgent
from agents.memoir_processor import BackgroundTaskRunner
from database import get_async_db
from database.models import Conversation, Segment
from database.models import User as UserModel
from services.auth_service import verify_token
from services.memoir_state_service import get_or_create_state
from services import asr_service, redis_service
from agents.prompts.profile_prompts import format_user_profile_context
# Module-level logger for this router.
logger = logging.getLogger(__name__)
# Voice-session id used when a client sends no voice_session_id, and for segment
# audio_url values stored in the older "audio-segment:{index}" format.
LEGACY_VOICE_SESSION_ID = "legacy"
class MessageType(str, Enum):
    """Values of the ``type`` field of messages exchanged over the conversation WebSocket.

    Members subclass ``str``, so they compare equal to the raw strings found in
    incoming JSON (``message.get("type") == MessageType.TEXT``).
    """

    # --- connection lifecycle ---
    CONNECT = "connect"
    # Client started recording; the server schedules a one-off "I'm listening" reply ~5s later.
    RECORDING_STARTED = "recording_started"

    # --- audio input ---
    AUDIO_CHUNK = "audio_chunk"
    AUDIO_SEGMENT = "audio_segment"  # one slice of a long, continuously uploaded recording
    AUDIO_MESSAGE = "audio_message"  # one complete voice message (WeChat-style)
    TRANSCRIBE_ONLY = "transcribe_only"  # transcribe and return text only: no persistence, no agent call

    # --- text in / out ---
    TEXT = "text"  # typed user message
    TRANSCRIPT = "transcript"  # speech-to-text result
    AGENT_RESPONSE = "agent_response"
    TTS_AUDIO = "tts_audio"

    # --- conversation lifecycle / misc ---
    END_CONVERSATION = "end_conversation"
    MEMOIR_UPDATE = "memoir_update"
    ERROR = "error"
# Connection management
class ConnectionManager:
    """Tracks live WebSocket connections and per-voice-session segment state.

    A single module-level instance is created right after this class. The
    agents it holds are shared by all connections.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        # Keyed by (conversation_id, voice_session_id).
        self.segment_states: Dict[Tuple[str, str], "SegmentStreamState"] = {}
        # ConversationAgent is stateless (sessions live in Redis), so one instance is reused.
        self.conversation_agent = ConversationAgent()
        self.memory_agent = MemoryAgent()
        self.background_runner = BackgroundTaskRunner()

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept *websocket* and register it under *conversation_id*.

        NOTE: an existing socket for the same conversation_id is simply replaced,
        and ``disconnect`` works by conversation_id only, so when the old socket's
        handler later calls it the new socket is unregistered as well.
        """
        await websocket.accept()
        self.active_connections[conversation_id] = websocket

    async def disconnect(self, conversation_id: str):
        """Unregister a conversation and release its idle segment state.

        Safe to call more than once. Pending delayed "listening" feedback tasks
        of the conversation are cancelled, so they cannot fire (and re-create a
        segment state) after the client is gone. States that still have running
        segment tasks are kept; ``register_segment_task`` drops them when their
        last task finishes.

        Redis conversation memory is deliberately NOT cleared here
        (``conversation_agent.clear_memory``) so it stays available for recovery.
        """
        self.active_connections.pop(conversation_id, None)
        conversation_keys = [key for key in self.segment_states if key[0] == conversation_id]
        for key in conversation_keys:
            state = self.segment_states[key]
            feedback_task = state.listening_feedback_task
            if feedback_task is not None and not feedback_task.done():
                feedback_task.cancel()
            state.listening_feedback_task = None
            if not state.active_tasks:
                self.segment_states.pop(key, None)

    def get_or_create_segment_state(
        self,
        conversation_id: str,
        voice_session_id: str,
    ) -> "SegmentStreamState":
        """Return the segment state for (conversation, voice session), creating it if missing."""
        state_key = (conversation_id, voice_session_id)
        state = self.segment_states.get(state_key)
        if state is None:
            state = SegmentStreamState()
            self.segment_states[state_key] = state
        return state

    def register_segment_task(
        self,
        conversation_id: str,
        voice_session_id: str,
        task: asyncio.Task,
    ) -> None:
        """Track *task* on its voice-session state and clean up when it finishes.

        When the state's last task finishes while the conversation is
        disconnected, the state is dropped. An exception raised by the task is
        logged together with its traceback.
        """
        state_key = (conversation_id, voice_session_id)
        state = self.get_or_create_segment_state(conversation_id, voice_session_id)
        state.active_tasks.add(task)

        def _cleanup(done_task: asyncio.Task) -> None:
            state.active_tasks.discard(done_task)
            if (
                not state.active_tasks
                and conversation_id not in self.active_connections
                # Never drop a newer state that replaced this one under the same key.
                and self.segment_states.get(state_key) is state
            ):
                self.segment_states.pop(state_key, None)
            if done_task.cancelled():
                return
            exc = done_task.exception()
            if exc:
                logger.error(
                    "分段处理任务异常 "
                    f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
                    # We are in a done-callback, not an except block, so
                    # exc_info=True would log no traceback; pass the exception itself.
                    exc_info=exc,
                )

        task.add_done_callback(_cleanup)

    async def send_message(self, conversation_id: str, message: dict):
        """Send *message* as JSON to the conversation's socket, if one is registered.

        Does nothing when no socket is registered. On any send error the failure
        is logged and the socket is unregistered.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}")
            # Only unregister the socket that actually failed; a reconnect may have
            # registered a new one for this conversation during the await.
            if self.active_connections.get(conversation_id) is websocket:
                del self.active_connections[conversation_id]

    async def receive_message(self, conversation_id: str) -> dict:
        """Receive one JSON message from the conversation's socket.

        Raises:
            HTTPException: 404 when no socket is registered for *conversation_id*.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is not None:
            return await websocket.receive_json()
        raise HTTPException(status_code=404, detail="Connection not found")
# Module-wide shared instance; websocket_endpoint and process_conversation_segments use it directly.
manager = ConnectionManager()
@dataclass
class SegmentStreamState:
    """Bookkeeping for one (conversation, voice session) of segmented voice input.

    Segments are transcribed concurrently but released to the agent in index
    order; these fields record which indices are in flight, finished, or
    buffered while waiting for an earlier index.
    """

    # Serializes bookkeeping updates made by segment tasks and the receive loop.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Indices currently being handled by a background task.
    pending_indices: Set[int] = field(default_factory=set)
    # Indices already persisted (or found already stored) by a task.
    processed_indices: Set[int] = field(default_factory=set)
    # Persisted segments waiting for every earlier index before being released.
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest index consumed in order (released to the agent, or primed from
    # already persisted segments); -1 means none yet.
    consumed_index: int = -1
    # Segment tasks still running for this voice session.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    # The "I'm listening" message goes out at most once, ~5s after recording
    # starts; the pending send is cancelled if the user finishes recording first.
    listening_feedback_sent: bool = False
    listening_feedback_task: Optional[asyncio.Task] = None
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime:
    """Stamp ``conversation.last_message_at`` and return the timestamp that was used.

    Uses *at* when given, otherwise the current UTC time. Only the attribute is
    set; committing is left to the caller.
    """
    stamp = at if at else _utc_now()
    conversation.last_message_at = stamp
    return stamp
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
    """Coerce a client-supplied voice session id to ``str``.

    Falsy input (None, empty string, 0) falls back to LEGACY_VOICE_SESSION_ID.
    """
    return str(voice_session_id) if voice_session_id else LEGACY_VOICE_SESSION_ID
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
if not client_segment_id:
return None
session_id, separator, _ = client_segment_id.rpartition("-")
if separator and session_id:
return session_id
return None
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
"""构建分段语音的幂等标识conversation_id + voice_session_id + segment_index"""
return f"audio-segment:{voice_session_id}:{segment_index}"
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse ``(voice_session_id, segment_index)`` from a segment ``audio_url``.

    Accepts ``audio-segment:{voice_session_id}:{index}`` and the older
    ``audio-segment:{index}`` form, which maps to LEGACY_VOICE_SESSION_ID.
    Returns None for any other value or when the index is not an integer.
    """
    prefix = "audio-segment:"
    if not (audio_url and audio_url.startswith(prefix)):
        return None
    body = audio_url[len(prefix):]
    session_part, sep, index_part = body.rpartition(":")
    try:
        if not sep:
            return (LEGACY_VOICE_SESSION_ID, int(body))
        return (_normalize_voice_session_id(session_part), int(index_part))
    except ValueError:
        return None
def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in a segment ``audio_url``, or None if it is not one."""
    parsed = _extract_segment_scope(audio_url)
    return parsed[0] if parsed else None
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
if not transcript_text:
return True
return transcript_text.startswith("转写失败")
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Return the newest persisted Segment for (conversation, voice session, index), or None.

    The SQL WHERE clause already filters. The rows are filtered again in Python
    because the test stub's ``execute()`` ignores WHERE clauses, so this works
    against both a real database and the stub.
    """
    target_url = _build_segment_audio_url(voice_session_id, segment_index)
    query = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == target_url,
        )
        .order_by(Segment.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id and row.audio_url == target_url
        ),
        None,
    )
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the largest N such that segments 0..N of this voice session are all persisted.

    Used to resume in-order processing after a reconnect. Returns -1 when
    segment 0 is not stored. The conversation_id is re-checked in Python,
    mirroring the stub-compatible filtering in ``_find_existing_segment_by_index``.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()

    stored_indices: Set[int] = set()
    for row in rows:
        if row.conversation_id != conversation_id:
            continue
        scope = _extract_segment_scope(row.audio_url)
        if scope and scope[0] == voice_session_id:
            stored_indices.add(scope[1])

    highest = -1
    while (highest + 1) in stored_indices:
        highest += 1
    return highest
# Seconds between a "recording_started" message and the one-off listening feedback.
LISTENING_FEEDBACK_DELAY_SEC = 5.0
# Feedback text ("I'm listening carefully, keep going; I'll note the key points as I listen.").
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
    manager: ConnectionManager,
) -> None:
    """Send the "I'm listening" transition message to the conversation's client.

    It is an AGENT_RESPONSE whose data carries ``transition: True`` (presumably
    so the client can tell it apart from a real agent reply). Invoked by the
    delayed feedback task.
    """
    payload = {
        "text": LISTENING_FEEDBACK_TEXT,
        "transition": True,
        "segment_index": segment_index,
    }
    envelope = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": payload,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, envelope)
async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
    manager: ConnectionManager,
) -> None:
    """After LISTENING_FEEDBACK_DELAY_SEC, send the "I'm listening" message once per voice session.

    The receive loop cancels this task when the last segment of the recording
    arrives first. If the voice session's state has been removed in the
    meantime (the connection was closed and cleaned up), nothing is sent and
    the state is NOT re-created. Re-creating it would leave an entry in
    ``manager.segment_states`` that nothing ever removes.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    state = manager.segment_states.get((conversation_id, voice_session_id))
    if state is None:
        return
    async with state.lock:
        if state.listening_feedback_sent:
            return
        state.listening_feedback_sent = True
        state.listening_feedback_task = None
    # Send outside the lock so segment bookkeeping is not blocked on network I/O.
    await _send_segment_transition_feedback(conversation_id, 0, manager)
async def _process_audio_segment_async(
conversation_id: str,
user_id: str,
voice_session_id: str,
segment_index: int,
audio_base64: str,
audio_duration: int,
is_last: bool,
manager: ConnectionManager,
) -> None:
"""
Process one uploaded audio segment in the background.
Transcribes it (segments of one voice session are transcribed concurrently by
separate tasks), persists it idempotently (audio_url encodes voice_session_id +
segment_index), and passes segments to process_user_message in index order:
a segment is released only once every earlier index has been released.
Errors are logged and reported to the client as ERROR messages. The index is
always removed from state.pending_indices at the end, so an index that failed
before being recorded as processed can be resent by the client.
"""
state = manager.get_or_create_segment_state(conversation_id, voice_session_id)
try:
# Each segment task opens its own DB session instead of sharing the receive loop's AsyncSession, avoiding concurrent use of one session.
async for db in get_async_db():
conversation = await db.get(Conversation, conversation_id)
user = await db.get(UserModel, user_id)
if not conversation:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "对话不存在,分段处理已取消"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
return
if not user:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "用户不存在,分段处理已取消"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
return
# On a fresh state (e.g. after a reconnect), resume ordering from the segments already persisted for this voice session.
async with state.lock:
should_prime_state = (
state.consumed_index < 0
and not state.processed_indices
and not state.buffered_transcripts
)
if should_prime_state:
persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
db=db,
conversation_id=conversation_id,
voice_session_id=voice_session_id,
)
if persisted_contiguous_index >= 0:
async with state.lock:
state.consumed_index = max(state.consumed_index, persisted_contiguous_index)
transcript_text = await asr_service.transcribe(audio_base64)
# The transcript (even a failed one) is always echoed to the client.
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {
"text": transcript_text or "",
"audio_duration": audio_duration,
"voice_session_id": voice_session_id,
"segment_index": segment_index,
"is_last": is_last,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
# NOTE(review): a failed index is never buffered, so every later index of this
# voice session stays in buffered_transcripts until this one is resent successfully.
if _is_transcribe_failure(transcript_text):
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {
"message": f"分段 {segment_index} 转写失败,请重试该片段",
"segment_index": segment_index,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
return
existing_segment = await _find_existing_segment_by_index(
db=db,
conversation_id=conversation_id,
voice_session_id=voice_session_id,
segment_index=segment_index,
)
if existing_segment:
# Already persisted: treat as a retransmission; don't store it again or re-trigger the agent.
# NOTE(review): the segment is not buffered here, so consumed_index only moves past
# this index if priming from the DB above already covered it.
async with state.lock:
state.processed_indices.add(segment_index)
logger.info(
"分段已存在,按幂等处理跳过: "
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
)
return
else:
segment = Segment(
id=str(uuid.uuid4()),
conversation_id=conversation_id,
transcript_text=transcript_text or "",
audio_url=_build_segment_audio_url(voice_session_id, segment_index),
processed=False,
)
db.add(segment)
user_message_timestamp = _mark_conversation_active(conversation)
await db.commit()
await db.refresh(segment)
await manager.background_runner.queue_message(conversation.user_id, segment.id)
# Buffer this segment, then collect every buffered segment that is now contiguous with consumed_index.
ready_segments: List[Tuple[int, str, Segment]] = []
async with state.lock:
state.processed_indices.add(segment_index)
state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
next_index = state.consumed_index + 1
while next_index in state.buffered_transcripts:
text, seg = state.buffered_transcripts.pop(next_index)
ready_segments.append((next_index, text, seg))
state.consumed_index = next_index
next_index += 1
# Only trigger the agent once all earlier segments are present, so incremental context stays in order.
# NOTE(review): if this loop runs after state.lock is released, another task can release the
# following indices and call process_user_message concurrently with this one; confirm that
# agent calls are really serialized in index order.
for _, ordered_text, ordered_segment in ready_segments:
await process_user_message(
conversation_id=conversation_id,
user_message=ordered_text,
conversation=conversation,
segment=ordered_segment,
db=db,
manager=manager,
user=user,
user_message_timestamp=ordered_segment.created_at or user_message_timestamp,
)
break
except Exception as e:
logger.error(
f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
exc_info=True,
)
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {
"message": f"分段处理失败: {str(e)}",
"segment_index": segment_index,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
finally:
# Always clear the in-flight marker, on success and failure alike.
async with state.lock:
state.pending_indices.discard(segment_index)
async def websocket_endpoint(
websocket: WebSocket,
conversation_id: str
):
"""
WebSocket endpoint for a real-time conversation.
Authenticates the `token` query parameter (must be an access JWT for an existing
user), loads the conversation or creates it, rejects conversations owned by
another user, sends an opening message when Redis holds no history for the
conversation, then loops over incoming JSON messages dispatching on `type`
(text / recording_started / audio_segment / audio_message / transcribe_only /
end_conversation). Messages with any other `type` are ignored.
Args:
websocket: WebSocket connection.
conversation_id: Conversation ID (also the key used in manager.active_connections).
"""
# Read the access token from the query string.
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
return
# Validate the JWT; only tokens whose payload type is "access" are accepted.
payload = verify_token(token)
if not payload or payload.get("type") != "access":
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
return
user_id = payload.get("sub")
if not user_id:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
return
# Make sure the user exists; the rest of the session uses this DB session.
async for db in get_async_db():
user = await db.get(UserModel, user_id)
if not user:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
return
await manager.connect(websocket, conversation_id)
try:
# Acknowledge the connection.
await manager.send_message(conversation_id, {
"type": MessageType.CONNECT,
"conversation_id": conversation_id,
"data": {"status": "connected"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# Load the conversation from the database.
conversation = await db.get(Conversation, conversation_id)
if not conversation:
# Unknown conversation: create it for this user.
conversation = Conversation(
id=conversation_id,
user_id=user_id,
started_at=datetime.now(timezone.utc),
status="active"
)
db.add(conversation)
await db.commit()
else:
# Authorization: users may only access their own conversations.
# NOTE(review): compares the token's `sub` with conversation.user_id directly; confirm both have the same type.
if conversation.user_id != user_id:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "无权访问此对话"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception:
pass # if sending fails, just close the connection
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
return
# Greet only on first entry: if Redis already has history (the user has been here before),
# skip the opening message to avoid repeating it or the agent answering itself.
history = await redis_service.get_conversation_history(conversation_id)
if not history:
# Empty conversation: send an opening (profile collection or the formal interview).
missing_profile = _get_missing_profile_fields(user)
if missing_profile:
try:
greetings = await manager.conversation_agent.generate_profile_greeting(
conversation_id=conversation_id,
missing_fields=missing_profile,
nickname=user.nickname or "",
)
# NOTE(review): asyncio is already imported at module level; this alias is redundant.
import asyncio as _asyncio_greet
for i, text in enumerate(greetings):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(greetings)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(greetings) - 1:
await _asyncio_greet.sleep(0.5)
except Exception as e:
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
else:
# Profile complete: the AI opens with a question.
try:
state = await get_or_create_state(user_id, db)
user_profile_context = format_user_profile_context(
birth_year=user.birth_year,
birth_place=user.birth_place,
grew_up_place=user.grew_up_place,
occupation=user.occupation,
)
opening_messages = await manager.conversation_agent.generate_opening_message(
conversation_id=conversation_id,
memoir_state=state,
user_profile_context=user_profile_context,
)
# NOTE(review): asyncio is already imported at module level; this alias is redundant.
import asyncio as _asyncio_open
for i, text in enumerate(opening_messages):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(opening_messages)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(opening_messages) - 1:
await _asyncio_open.sleep(0.5)
except Exception as e:
logger.error(f"发送空对话开场白失败: {e}", exc_info=True)
# Main loop: receive and dispatch client messages.
while True:
try:
if websocket.application_state != WebSocketState.CONNECTED:
logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}")
break
message = await websocket.receive_json()
msg_type = message.get("type")
if msg_type == MessageType.TEXT:
# Typed text message.
text_message = message.get("data", {}).get("text", "")
if text_message:
# Enforce the conversation-turn quota.
from routers.quota import get_segment_count, check_can_send_message
seg_count = await get_segment_count(user_id, db)
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
# Persist the message as a segment.
segment = Segment(
id=str(uuid.uuid4()),
conversation_id=conversation_id,
transcript_text=text_message,
processed=False
)
db.add(segment)
user_message_timestamp = _mark_conversation_active(conversation)
await db.commit()
await db.refresh(segment)
await manager.background_runner.queue_message(conversation.user_id, segment.id)
# Let the agent reply.
await process_user_message(
conversation_id=conversation_id,
user_message=text_message,
conversation=conversation,
segment=segment,
db=db,
manager=manager,
user=user,
user_message_timestamp=segment.created_at or user_message_timestamp,
)
elif msg_type == MessageType.RECORDING_STARTED:
# User started recording: schedule the one-off "I'm listening" message (LISTENING_FEEDBACK_DELAY_SEC later).
data = message.get("data", {})
voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
segment_state = manager.get_or_create_segment_state(
conversation_id,
voice_session_id,
)
async with segment_state.lock:
if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done():
continue # a send is already pending for this voice session; don't schedule another
if segment_state.listening_feedback_sent:
continue
delayed_task = asyncio.create_task(
_delayed_listening_feedback(
conversation_id=conversation_id,
voice_session_id=voice_session_id,
manager=manager,
)
)
segment_state.listening_feedback_task = delayed_task
elif msg_type == MessageType.AUDIO_SEGMENT:
# One segment of a long recording that is uploaded continuously; processed in a background task.
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
segment_index_raw = data.get("segment_index")
voice_session_id = _normalize_voice_session_id(
data.get("voice_session_id")
or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
)
is_last = bool(data.get("is_last", False))
# NOTE(review): a non-numeric "duration" raises ValueError here and is handled by the generic except below.
audio_duration = int(data.get("duration", 0) or 0)
if not audio_base64:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
if segment_index_raw is None:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 segment_index"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
try:
segment_index = int(segment_index_raw)
except (TypeError, ValueError):
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "segment_index 必须为整数"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
if segment_index < 0:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "segment_index 不能为负数"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
# Enforce the conversation-turn quota (segments count as turns too).
# NOTE(review): this runs before the duplicate check below, so a retransmitted segment
# can be rejected for quota even though it would have been skipped anyway.
from routers.quota import get_segment_count, check_can_send_message
seg_count = await get_segment_count(user_id, db)
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
segment_state = manager.get_or_create_segment_state(
conversation_id,
voice_session_id,
)
# Skip indices that are in flight, already processed, or already consumed in order.
should_process = False
async with segment_state.lock:
already_seen = (
segment_index in segment_state.pending_indices
or segment_index in segment_state.processed_indices
or segment_index <= segment_state.consumed_index
)
if not already_seen:
segment_state.pending_indices.add(segment_index)
should_process = True
if not should_process:
logger.info(
"收到重复分段,跳过处理: "
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
)
continue
# If this is the last segment (user stopped recording), cancel a not-yet-sent "I'm listening" so it doesn't arrive after the recording ended.
if is_last:
async with segment_state.lock:
t = segment_state.listening_feedback_task
segment_state.listening_feedback_task = None
if t is not None and not t.done():
t.cancel()
task = asyncio.create_task(
_process_audio_segment_async(
conversation_id=conversation_id,
user_id=user_id,
voice_session_id=voice_session_id,
segment_index=segment_index,
audio_base64=audio_base64,
audio_duration=audio_duration,
is_last=is_last,
manager=manager,
)
)
manager.register_segment_task(conversation_id, voice_session_id, task)
elif msg_type == MessageType.AUDIO_MESSAGE:
# A complete voice message (WeChat-style), processed inline.
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
audio_duration = data.get("duration", 0)
if audio_base64:
# Enforce the conversation-turn quota.
from routers.quota import get_segment_count, check_can_send_message
seg_count = await get_segment_count(user_id, db)
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
logger.info(f"收到音频消息,时长: {audio_duration}s")
try:
# 1. Speech-to-text.
transcript_text = await asr_service.transcribe(audio_base64)
logger.info(f"ASR 转写结果: {transcript_text}")
# 2. Send the transcript to the client.
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {
"text": transcript_text,
"audio_duration": audio_duration
},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 3. Persist the segment (transcript plus a marker for the audio).
# NOTE(review): this happens before the failure check below, so a failed transcription
# is stored and queued for background processing as well.
segment = Segment(
id=str(uuid.uuid4()),
conversation_id=conversation_id,
transcript_text=transcript_text,
audio_url=f"audio:{audio_duration}s", # simplified storage: only marks this as a voice message
processed=False
)
db.add(segment)
user_message_timestamp = _mark_conversation_active(conversation)
await db.commit()
await db.refresh(segment)
await manager.background_runner.queue_message(conversation.user_id, segment.id)
# 4. Agent reply based on the transcript (same condition as `not _is_transcribe_failure(...)`).
if transcript_text and not transcript_text.startswith("转写失败"):
await process_user_message(
conversation_id=conversation_id,
user_message=transcript_text,
conversation=conversation,
segment=segment,
db=db,
manager=manager,
user=user,
user_message_timestamp=segment.created_at or user_message_timestamp,
)
else:
# Transcription failed: report it to the client.
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "语音转写失败,请重试或使用文字输入"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as e:
logger.error(f"处理音频消息失败: {e}", exc_info=True)
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": f"处理音频消息失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
elif msg_type == MessageType.TRANSCRIBE_ONLY:
# Transcribe only: nothing is persisted and the agent is not called; just return the text.
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
if not audio_base64:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
try:
transcript_text = await asr_service.transcribe(audio_base64)
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {"text": transcript_text or ""},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as e:
logger.error(f"仅转写失败: {e}", exc_info=True)
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": f"转写失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
elif msg_type == MessageType.END_CONVERSATION:
# End the conversation.
conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc)
await db.commit()
# Hand the remaining segments to memoir processing before acknowledging.
await process_conversation_segments(conversation_id, db)
await manager.send_message(conversation_id, {
"type": MessageType.END_CONVERSATION,
"conversation_id": conversation_id,
"data": {"status": "ended"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
break
except RuntimeError as e:
# Detect "disconnected / not connected" errors (e.g. the socket closed before or after accept).
error_msg = str(e)
if (
"disconnect" in error_msg.lower()
or "Cannot call \"receive\"" in error_msg
or "accept" in error_msg.lower() and "not connected" in error_msg.lower()
):
logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}")
break
else:
logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
# Only try to report the error while the connection is still registered.
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
break
except WebSocketDisconnect:
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
break
except Exception as e:
logger.error(f"处理消息时发生错误: {e}", exc_info=True)
# Only try to report the error while the connection is still registered.
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
break
except WebSocketDisconnect:
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
await manager.disconnect(conversation_id)
except Exception as e:
logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
await manager.disconnect(conversation_id)
finally:
# Always unregister the connection (disconnect() tolerates repeated calls).
await manager.disconnect(conversation_id)
def _get_missing_profile_fields(user: UserModel) -> list:
    """Return the profile fields *user* has not provided yet.

    Delegates to ``agents.prompts.profile_prompts.get_missing_profile_fields``
    with the user's birth_year, birth_place, grew_up_place and occupation.
    """
    from agents.prompts.profile_prompts import get_missing_profile_fields as _missing_fields

    profile = {
        "birth_year": user.birth_year,
        "birth_place": user.birth_place,
        "grew_up_place": user.grew_up_place,
        "occupation": user.occupation,
    }
    return _missing_fields(**profile)
def _get_filled_profile_fields(user: UserModel) -> dict:
    """Return the profile fields *user* has already filled in.

    Keys are the raw field names (``birth_year``, ``birth_place``,
    ``grew_up_place``, ``occupation``), not display labels; fields whose value
    is falsy are omitted. ``birth_year`` is converted to ``str``; the other
    values are returned as stored.
    """
    filled = {}
    if user.birth_year:
        filled["birth_year"] = str(user.birth_year)
    for field_name in ("birth_place", "grew_up_place", "occupation"):
        value = getattr(user, field_name)
        if value:
            filled[field_name] = value
    return filled
async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession):
    """Save newly extracted profile values onto *user*.

    A field is written only when *extracted* contains a non-empty value for it
    and the user has no (truthy) value yet, so existing profile data is never
    overwritten. The session is committed and *user* refreshed only when at
    least one field actually changed.

    Args:
        user: User whose profile may be completed.
        extracted: Field name -> extracted value. Only birth_year, birth_place,
            grew_up_place and occupation are read.
        db: Session the user object belongs to.
    """
    changed = False
    for field_name in ("birth_year", "birth_place", "grew_up_place", "occupation"):
        if field_name not in extracted or getattr(user, field_name):
            continue
        new_value = extracted[field_name]
        # An empty extraction carries no information; don't let it trigger a commit.
        if new_value is None or new_value == "":
            continue
        setattr(user, field_name, new_value)
        changed = True
    if changed:
        await db.commit()
        await db.refresh(user)
async def _send_agent_responses(
    conversation_id: str,
    responses: List[str],
    manager: ConnectionManager,
) -> None:
    """Send each reply as its own AGENT_RESPONSE message, pausing 0.5s between replies."""
    total = len(responses)
    for i, response_text in enumerate(responses):
        await manager.send_message(conversation_id, {
            "type": MessageType.AGENT_RESPONSE,
            "conversation_id": conversation_id,
            "data": {"text": response_text, "index": i, "total": total},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
        if i < total - 1:
            await asyncio.sleep(0.5)


async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager,
    user: Optional[UserModel] = None,
    user_message_timestamp: Optional[datetime] = None,
) -> None:
    """Generate and send the agent's reply to one user message.

    Two modes:

    * Profile collection: while *user* still has missing profile fields, try to
      extract them from the message, save what was found, and reply with a
      follow-up. If this path raises, the error is logged and the formal
      interview mode is used instead.
    * Formal interview: sync ``conversation.conversation_stage`` with the user's
      memoir state, then reply via ``generate_response_with_state``.

    In both modes the replies are joined into ``segment.agent_response``, the
    conversation is marked active, the session is committed, and each reply is
    sent as a separate AGENT_RESPONSE. Errors while generating the interview
    reply are logged and reported to the client as an ERROR message.
    """
    agent = manager.conversation_agent

    # --- Profile collection mode ---
    if user:
        missing = _get_missing_profile_fields(user)
        if missing:
            try:
                extracted = await agent.extract_profile_from_message(
                    user_message, missing, conversation_id=conversation_id
                )
                if extracted:
                    await _apply_extracted_profile(user, extracted, db)
                remaining = _get_missing_profile_fields(user)
                filled = _get_filled_profile_fields(user)
                responses = await agent.generate_profile_followup(
                    conversation_id=conversation_id,
                    user_message=user_message,
                    missing_fields=remaining,
                    filled_fields=filled,
                    nickname=user.nickname or "",
                    is_from_voice=bool(segment.audio_url),
                    voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
                    user_message_timestamp=user_message_timestamp,
                )
                segment.agent_response = "\n\n".join(responses)
                _mark_conversation_active(conversation)
                await db.commit()
                await _send_agent_responses(conversation_id, responses, manager)
                return
            except Exception as e:
                # Best-effort: fall through to the formal interview mode.
                logger.error(f"资料收集处理失败: {e}", exc_info=True)

    # --- Formal interview mode ---
    state = await get_or_create_state(conversation.user_id, db)
    if conversation.conversation_stage != state.current_stage:
        conversation.conversation_stage = state.current_stage
        await db.commit()

    # User profile context for the prompt (empty when no user is given).
    user_profile_context = ""
    if user:
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )

    try:
        responses = await agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            is_from_voice=bool(segment.audio_url),
            voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
            user_message_timestamp=user_message_timestamp,
        )
        segment.agent_response = "\n\n".join(responses)
        _mark_conversation_active(conversation)
        await db.commit()
        await _send_agent_responses(conversation_id, responses, manager)
    except Exception as e:
        logger.error(f"处理用户消息失败: {e}", exc_info=True)
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": f"生成回应失败: {str(e)}"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
            except Exception as send_error:
                logger.warning(f"发送错误消息失败: {send_error}")
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """Hand an ended conversation's unprocessed segments to memoir processing.

    Called when the conversation ends. Most processing normally already happened
    incrementally via Celery; this submits every still-unprocessed segment
    directly to the ``process_memoir_segments`` Celery task (bypassing the
    debounce), unless the user's chapter quota is used up. In every path that
    finds the conversation, the background runner's pending work for the user
    is flushed.

    Args:
        conversation_id: Conversation ID.
        db: Async database session.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    # All segments of this conversation not yet marked as processed.
    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed.is_(False),
    )
    segments = (await db.execute(stmt)).scalars().all()
    if not segments:
        # Nothing to submit; just flush whatever the runner still holds.
        await manager.background_runner.flush_pending(conversation.user_id)
        return

    # The free tier allows organizing a single chapter; check the quota before submitting.
    from routers.quota import get_chapter_count, check_can_submit_organize

    user = await db.get(UserModel, conversation.user_id)
    if user:
        chapter_count = await get_chapter_count(user.id, db)
        can_submit, _ = check_can_submit_organize(user.subscription_type, chapter_count)
        if not can_submit:
            logger.info(
                f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}"
            )
            await manager.background_runner.flush_pending(conversation.user_id)
            return

    # Submit directly to Celery, bypassing the debounce.
    segment_ids = [seg.id for seg in segments]
    try:
        from tasks.memoir_tasks import process_memoir_segments
        # NOTE(review): Celery's .delay() presumably publishes to the broker synchronously,
        # blocking the event loop while it does so.
        process_memoir_segments.delay(conversation.user_id, segment_ids)
        logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
    except Exception as e:
        logger.error(f"提交 Celery 任务失败: {e}", exc_info=True)
    # Also flush any work still pending in the background runner.
    await manager.background_runner.flush_pending(conversation.user_id)