Replace CreateRecTask polling with recording-file flash API, add TENCENT_APP_ID, remove server-side pydub slicing, and log ASR recognition text at INFO in development. Co-authored-by: Cursor <cursoragent@cursor.com>
879 lines
41 KiB
Python
879 lines
41 KiB
Python
"""
|
||
WebSocket 路由:实时对话通信
|
||
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
|
||
"""
|
||
|
||
import asyncio
|
||
import base64
|
||
from datetime import datetime, timezone
|
||
|
||
from fastapi import WebSocket, WebSocketDisconnect, status
|
||
from starlette.websockets import WebSocketState
|
||
|
||
from app.agents.chat.background_voice import infer_background_voice
|
||
from app.agents.chat.prompts_profile import format_user_profile_context
|
||
from app.core.agent_logging import log_asr_transcript_result
|
||
from app.core.config import settings
|
||
from app.core.db import AsyncSessionLocal
|
||
from app.core.dependencies import get_asr_provider
|
||
from app.core.logging import get_logger
|
||
from app.core.security import verify_token
|
||
from app.features.conversation.service import ConversationService
|
||
from app.features.conversation.ws.connection_manager import manager
|
||
from app.features.conversation.ws.message_types import MessageType
|
||
from app.features.conversation.ws.pipeline import (
|
||
_delayed_listening_feedback,
|
||
_voice_session_id_from_client_segment_id,
|
||
bump_tts_cancel_epoch,
|
||
chat_orchestrator,
|
||
cleanup_segment_states,
|
||
get_or_create_segment_state,
|
||
handle_tts_request_on_demand,
|
||
process_audio_segment,
|
||
process_conversation_segments,
|
||
process_persisted_user_segment_response,
|
||
register_segment_task,
|
||
register_user_response_task,
|
||
)
|
||
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
|
||
from app.features.conversation.ws.quota_guard import check_ws_quota
|
||
from app.features.conversation.ws.topic_chips_push import maybe_send_topic_chips_ws
|
||
from app.features.memoir.service import MemoirService
|
||
from app.features.quota.service import QuotaService
|
||
from app.features.user.service import UserService
|
||
from app.features.conversation.constants import chat
|
||
|
||
# Module-level logger for the conversation WebSocket route (created via the
# project's get_logger factory; messages use "{}"-style placeholders).
logger = get_logger(__name__)
|
||
|
||
|
||
def _idle_hours_since(ts) -> float | None:
|
||
"""计算距 ts 的小时数;ts 为 None 或非 datetime 时返回 None。"""
|
||
if ts is None:
|
||
return None
|
||
if not isinstance(ts, datetime):
|
||
return None
|
||
if ts.tzinfo is None:
|
||
ts = ts.replace(tzinfo=timezone.utc)
|
||
delta = datetime.now(timezone.utc) - ts
|
||
return max(0.0, delta.total_seconds() / 3600.0)
|
||
|
||
|
||
def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (message ``timestamp``)."""
    return datetime.now(timezone.utc).isoformat()


def _safe_int(value, default: int = 0) -> int:
    """Convert an untrusted client value to ``int`` without raising.

    Falsy values (``None``, ``""``, ``0``) and anything ``int()`` rejects
    (e.g. ``"abc"``, ``"1.5"``) yield ``default``. Any exception raised inside
    the receive loop ends the connection, so malformed optional fields must
    not raise.
    """
    if not value:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


async def _send_ws_error(
    conversation_id: str, message: str, code: str | None = None
) -> None:
    """Push an ``ERROR`` frame to the client of ``conversation_id``.

    The payload is ``{"type": ERROR, "data": {"message": ..., ["code": ...]},
    "timestamp": ...}``; ``code`` is only included when given.
    """
    data: dict[str, str] = {"message": message}
    if code is not None:
        data["code"] = code
    await manager.send_message(
        conversation_id,
        {
            "type": MessageType.ERROR,
            "data": data,
            "timestamp": _utc_now_iso(),
        },
    )


async def _ensure_ws_quota(
    conversation_id: str,
    quota_service: QuotaService,
    user_id: str,
    subscription_type,
) -> bool:
    """Check the user's quota via ``check_ws_quota``.

    When the quota is exhausted, a ``QUOTA_EXCEEDED`` error frame is sent.

    Returns:
        ``True`` if the incoming message may be processed, ``False`` otherwise.
    """
    can_send, quota_msg = await check_ws_quota(
        quota_service, user_id, subscription_type
    )
    if not can_send:
        await _send_ws_error(conversation_id, quota_msg, code="QUOTA_EXCEEDED")
    return bool(can_send)


async def _transcribe_m4a_base64(audio_base64: str):
    """Decode base64 audio and transcribe it with the configured ASR provider.

    The audio is passed to the provider with format ``"m4a"``. Decoding or
    provider errors propagate to the caller.
    """
    asr = get_asr_provider()
    audio_bytes = base64.b64decode(audio_base64)
    return await asr.transcribe(audio_bytes, "m4a")


def _profile_prompt_kwargs(user, language: str) -> dict:
    """Build the profile-derived keyword arguments for the greeting generators.

    These are shared by ``chat_orchestrator.generate_opening_message`` and
    ``chat_orchestrator.generate_re_greeting_message``.
    """
    user_profile_context = format_user_profile_context(
        birth_year=user.birth_year,
        birth_place=user.birth_place,
        grew_up_place=user.grew_up_place,
        occupation=user.occupation,
        language=language,
    )
    return {
        "user_profile_context": user_profile_context,
        "background_voice": infer_background_voice(user.occupation),
        "occupation": user.occupation or "",
        "profile_birth_year": user.birth_year,
        # The place where the user grew up is preferred over the birthplace.
        "profile_era_place": user.grew_up_place or user.birth_place or "",
        "language": language,
    }


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """WebSocket endpoint handling a real-time conversation.

    Lifecycle:
      1. Authenticate via the ``token`` query parameter. It must be an
         ``access`` token whose ``sub`` refers to an existing user; otherwise
         the socket is closed with code 1008.
      2. Register the connection and send ``CONNECT``. Validate access to the
         conversation: ``forbidden`` and ``deleted`` close with code 1008.
      3. Align the conversation stage with the memoir state. Then send one of:
         a profile-collection greeting, an opening message, a re-greeting
         after a long idle period, or topic chips.
      4. Dispatch incoming JSON frames by ``type``. Stop when the client ends
         the conversation, disconnects, or an unexpected error occurs.

    The connection and per-segment state are released in ``finally`` on every
    exit path after the connection has been registered.

    Args:
        websocket: WebSocket connection.
        conversation_id: Conversation ID.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return

    async with AsyncSessionLocal() as db:
        user_service = UserService(db)
        user = await user_service.get_by_id(user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )

        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)
        memoir_service = MemoirService(db=db)

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _utc_now_iso(),
                },
            )

            conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
                conversation_id, user_id
            )
            # Access errors are reported both as an ERROR frame and as the close reason.
            reject_reason = {
                "forbidden": "无权访问此对话",
                "deleted": "对话已删除",
            }.get(ws_conn_err)
            if reject_reason is not None:
                try:
                    await _send_ws_error(conversation_id, reject_reason)
                except Exception as notify_error:
                    # Best effort: the close frame below still carries the reason.
                    logger.debug("发送拒绝原因失败: {}", notify_error)
                await websocket.close(
                    code=status.WS_1008_POLICY_VIOLATION, reason=reject_reason
                )
                return

            # Cold start: align conversation_stage with MemoirState.current_stage.
            # If the conversation row is already at a later life stage (larger
            # STAGE_TO_ORDER), it is not overwritten, so the stage never regresses.
            memoir_state = await memoir_service.get_or_create_memoir_state(user_id)
            await conversation_service.align_conversation_stage_from_memoir(
                conversation, memoir_state.current_stage or ""
            )
            await db.refresh(conversation)

            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )
            user_language = (
                "en"
                if str(getattr(user, "language_preference", "zh") or "zh").lower() == "en"
                else "zh"
            )

            async def _stream_ai_only_messages(
                texts: list[str], log_label: str
            ) -> None:
                """Persist a batch of AI-only messages, then push them one frame each.

                Each element of ``texts`` becomes its own AGENT_RESPONSE frame
                (the list is presumably already split on ``[SPLIT]`` upstream).
                Nothing is sent when ``texts`` is empty or persisting yields no
                message id.
                """
                if not texts:
                    return
                ai_msg_id = await conversation_service.record_ai_only_turn(
                    conversation_id, texts
                )
                if not ai_msg_id:
                    return
                total_n = len(texts)
                for i, text in enumerate(texts):
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.AGENT_RESPONSE,
                            "conversation_id": conversation_id,
                            "data": {
                                "text": text,
                                "index": i,
                                "total": total_n,
                                "assistant_message_id": ai_msg_id,
                            },
                            "timestamp": _utc_now_iso(),
                        },
                    )
                    if i < total_n - 1:
                        # Pace consecutive segments 0.5 s apart.
                        await asyncio.sleep(0.5)
                logger.info(
                    "event=ws_auto_ai_sent label={} conversation_id={} segments={}",
                    log_label,
                    conversation_id,
                    total_n,
                )

            async def _maybe_send_topic_chips(reason: str) -> None:
                """Delegate to ``maybe_send_topic_chips_ws`` with this connection's context."""
                await maybe_send_topic_chips_ws(
                    conversation_id,
                    user=user,
                    memoir_state=memoir_state,
                    reason=reason,
                    language=user_language,
                )

            if not history:
                missing_profile = get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await chat_orchestrator.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                            language=user_language,
                        )
                        await _stream_ai_only_messages(
                            greetings, log_label="profile_greeting"
                        )
                    except Exception as e:
                        logger.exception("发送资料收集开场白失败: {}", e)
                else:
                    try:
                        opening_messages = (
                            await chat_orchestrator.generate_opening_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                **_profile_prompt_kwargs(user, user_language),
                            )
                        )
                        await _stream_ai_only_messages(
                            opening_messages, log_label="opening"
                        )
                        await _maybe_send_topic_chips(reason="opening")
                    except Exception as e:
                        logger.exception("发送空对话开场白失败: {}", e)
            else:
                # History is non-empty: greet again if the user has been idle
                # for at least the configured threshold.
                idle_hours = _idle_hours_since(conversation.last_message_at)
                threshold = float(chat.re_greeting_idle_hours)
                if (
                    chat.re_greeting_enabled
                    and not get_missing_profile_fields(user)
                    and idle_hours is not None
                    and idle_hours >= threshold
                ):
                    try:
                        re_greetings = (
                            await chat_orchestrator.generate_re_greeting_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                idle_hours=idle_hours,
                                **_profile_prompt_kwargs(user, user_language),
                            )
                        )
                        await _stream_ai_only_messages(
                            re_greetings, log_label="re_greeting"
                        )
                        logger.info(
                            "event=ws_re_greeting_emitted conversation_id={} "
                            "idle_hours={:.2f} threshold={:.2f}",
                            conversation_id,
                            idle_hours,
                            threshold,
                        )
                        await _maybe_send_topic_chips(reason="re_greeting")
                    except Exception as e:
                        logger.exception("发送回访问候失败: {}", e)
                else:
                    # No re-greeting: still offer chips to lower the barrier to
                    # resuming. This is optional, so a failure must not close the
                    # socket (consistent with the other greeting paths).
                    try:
                        await _maybe_send_topic_chips(reason="resume")
                    except Exception as e:
                        logger.exception("下发话题 chips 失败: {}", e)

            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break
                    message = await websocket.receive_json()
                    # Client frames are untrusted. A non-object frame is treated
                    # like one without "type", and a null/non-object "data" like
                    # an empty one. Without this, the AttributeError would reach
                    # the generic handler and close the connection.
                    if not isinstance(message, dict):
                        message = {}
                    msg_type = message.get("type")
                    data = message.get("data") or {}
                    if not isinstance(data, dict):
                        data = {}

                    if msg_type == MessageType.AUDIO_SEGMENT:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={} "
                            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
                            msg_type,
                            conversation_id,
                            data.get("segment_index"),
                            bool(data.get("is_last")),
                            _safe_int(data.get("duration")),
                            len(data.get("audio_base64") or ""),
                        )
                    elif msg_type is not None:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                    else:
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )

                    if msg_type == MessageType.TEXT:
                        text_message = data.get("text", "")
                        tts_this_turn = bool(data.get("tts_this_turn"))
                        if text_message:
                            if not await _ensure_ws_quota(
                                conversation_id,
                                quota_service,
                                user_id,
                                user.subscription_type,
                            ):
                                continue
                            segment = await conversation_service.create_user_segment(
                                conversation,
                                user_id,
                                text_message,
                            )
                            task = asyncio.create_task(
                                process_persisted_user_segment_response(
                                    conversation_id=conversation_id,
                                    user_id=user_id,
                                    segment_id=segment.id,
                                    tts_this_turn=tts_this_turn,
                                )
                            )
                            register_user_response_task(conversation_id, task)

                    elif msg_type == MessageType.RECORDING_STARTED:
                        raw_vs = data.get("voice_session_id")
                        if not raw_vs or not str(raw_vs).strip():
                            await _send_ws_error(
                                conversation_id, "缺少 voice_session_id"
                            )
                            continue
                        voice_session_id = str(raw_vs).strip()
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            # Schedule at most one delayed "listening" feedback
                            # per voice session.
                            pending_feedback = segment_state.listening_feedback_task
                            if (
                                pending_feedback is not None
                                and not pending_feedback.done()
                            ):
                                continue
                            if segment_state.listening_feedback_sent:
                                continue
                            segment_state.listening_feedback_task = asyncio.create_task(
                                _delayed_listening_feedback(
                                    conversation_id=conversation_id,
                                    voice_session_id=voice_session_id,
                                )
                            )

                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        audio_base64 = data.get("audio_base64", "")
                        segment_index_raw = data.get("segment_index")
                        resolved_vs = data.get("voice_session_id") or (
                            _voice_session_id_from_client_segment_id(
                                data.get("client_segment_id")
                            )
                        )
                        if not resolved_vs or not str(resolved_vs).strip():
                            await _send_ws_error(
                                conversation_id,
                                "缺少 voice_session_id 或有效的 client_segment_id",
                            )
                            continue
                        voice_session_id = str(resolved_vs).strip()
                        is_last = bool(data.get("is_last", False))
                        audio_duration = _safe_int(data.get("duration", 0))
                        tts_this_turn_segment = bool(data.get("tts_this_turn"))

                        if not audio_base64:
                            await _send_ws_error(conversation_id, "缺少 audio_base64")
                            continue
                        if segment_index_raw is None:
                            await _send_ws_error(conversation_id, "缺少 segment_index")
                            continue
                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError):
                            await _send_ws_error(
                                conversation_id, "segment_index 必须为整数"
                            )
                            continue
                        if segment_index < 0:
                            await _send_ws_error(
                                conversation_id, "segment_index 不能为负数"
                            )
                            continue

                        if not await _ensure_ws_quota(
                            conversation_id,
                            quota_service,
                            user_id,
                            user.subscription_type,
                        ):
                            continue

                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        # De-duplicate (e.g. client retries): an index that is
                        # pending, processed, or at/below the consumed watermark
                        # is skipped.
                        async with segment_state.lock:
                            already_seen = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not already_seen:
                                segment_state.pending_indices.add(segment_index)

                        if already_seen:
                            logger.debug(
                                "收到重复分段,跳过: conversation_id={} voice_session_id={} "
                                "segment_index={} audio_b64_len={} duration={}",
                                conversation_id,
                                voice_session_id,
                                segment_index,
                                len(audio_base64 or ""),
                                audio_duration,
                            )
                            continue

                        if is_last:
                            # The recording is finished: drop any still-pending
                            # delayed listening feedback for this voice session.
                            async with segment_state.lock:
                                feedback_task = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                                if feedback_task is not None and not feedback_task.done():
                                    feedback_task.cancel()

                        task = asyncio.create_task(
                            process_audio_segment(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                                tts_this_turn=tts_this_turn_segment,
                            )
                        )
                        register_segment_task(conversation_id, voice_session_id, task)

                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        audio_base64 = data.get("audio_base64", "")
                        audio_duration = data.get("duration", 0)
                        tts_this_turn = bool(data.get("tts_this_turn"))

                        if audio_base64:
                            if not await _ensure_ws_quota(
                                conversation_id,
                                quota_service,
                                user_id,
                                user.subscription_type,
                            ):
                                continue

                            logger.debug(
                                "收到音频消息: conversation_id={} duration_s={}",
                                conversation_id,
                                audio_duration,
                            )

                            try:
                                asr_text = await _transcribe_m4a_base64(audio_base64)
                                log_asr_transcript_result(
                                    logger,
                                    text=asr_text or "",
                                    conversation_id=conversation_id,
                                    duration_s=audio_duration,
                                    source="audio_message",
                                )

                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.TRANSCRIPT,
                                        "conversation_id": conversation_id,
                                        "data": {
                                            "text": asr_text,
                                            "audio_duration": audio_duration,
                                        },
                                        "timestamp": _utc_now_iso(),
                                    },
                                )

                                ads = _safe_int(audio_duration)
                                segment = await conversation_service.create_user_segment(
                                    conversation,
                                    user_id,
                                    asr_text,
                                    audio_url=f"audio:{audio_duration}s",
                                    audio_duration_seconds=ads if ads > 0 else None,
                                )

                                # A transcription text starting with "转写失败"
                                # (transcription failed) is treated as a failure.
                                if asr_text and not asr_text.startswith("转写失败"):
                                    task = asyncio.create_task(
                                        process_persisted_user_segment_response(
                                            conversation_id=conversation_id,
                                            user_id=user_id,
                                            segment_id=segment.id,
                                            tts_this_turn=tts_this_turn,
                                        )
                                    )
                                    register_user_response_task(conversation_id, task)
                                else:
                                    await _send_ws_error(
                                        conversation_id,
                                        "语音转写失败,请重试或使用文字输入",
                                    )

                            except Exception as e:
                                logger.exception("处理音频消息失败: {}", e)
                                await _send_ws_error(
                                    conversation_id,
                                    "语音处理失败,请重试或使用文字输入",
                                )

                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        audio_base64 = data.get("audio_base64", "")
                        if not audio_base64:
                            await _send_ws_error(conversation_id, "缺少 audio_base64")
                            continue
                        try:
                            asr_text = await _transcribe_m4a_base64(audio_base64)
                            log_asr_transcript_result(
                                logger,
                                text=asr_text or "",
                                conversation_id=conversation_id,
                                source="transcribe_only",
                            )
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {"text": asr_text or ""},
                                    "timestamp": _utc_now_iso(),
                                },
                            )
                        except Exception as e:
                            logger.exception("仅转写失败: {}", e)
                            await _send_ws_error(conversation_id, "语音转写失败,请重试")

                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)

                    elif msg_type == MessageType.TTS_REQUEST:
                        # Both snake_case and camelCase keys are accepted.
                        aid = data.get("assistant_message_id") or data.get(
                            "assistantMessageId"
                        )
                        if not aid or not str(aid).strip():
                            logger.warning(
                                "ws.TTS_REQUEST 缺少 assistant_message_id "
                                "conversation_id={} user_id={}",
                                conversation_id,
                                user_id,
                            )
                            await _send_ws_error(conversation_id, "缺少助手消息 id")
                            continue
                        try:
                            seg_idx = int(
                                data.get("segment_index", data.get("segmentIndex", 0))
                            )
                        except (TypeError, ValueError):
                            seg_idx = 0
                        st = data.get("segment_text") or data.get("segmentText")
                        st_val: str | None = (
                            None if st is None else (str(st).strip() or None)
                        )
                        ok, err_msg = await handle_tts_request_on_demand(
                            conversation_id=conversation_id,
                            user_id=user_id,
                            assistant_message_id=str(aid).strip(),
                            segment_index=seg_idx,
                            segment_text=st_val,
                            db=db,
                        )
                        if not ok:
                            await _send_ws_error(
                                conversation_id, err_msg or "朗读请求失败"
                            )

                    elif msg_type == MessageType.END_CONVERSATION:
                        await conversation_service.end(conversation_id, user_id)
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": _utc_now_iso(),
                            },
                        )
                        break

                    elif msg_type == MessageType.PING:
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": _utc_now_iso(),
                            },
                        )

                    elif msg_type is not None:
                        logger.warning(
                            "WebSocket 未识别的消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )

                except RuntimeError as e:
                    error_msg = str(e)
                    lowered = error_msg.lower()
                    # These substrings match the RuntimeErrors raised when
                    # receiving after a disconnect or before accept(). Treat
                    # them as a normal end of the session.
                    if (
                        "disconnect" in lowered
                        or 'Cannot call "receive"' in error_msg
                        or ("accept" in lowered and "not connected" in lowered)
                    ):
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                        break
                    logger.exception("处理消息时发生 RuntimeError: {}", e)
                    if conversation_id in manager.active_connections:
                        try:
                            await _send_ws_error(conversation_id, "处理失败,请重试")
                        except Exception as send_error:
                            logger.warning("发送错误消息失败: {}", send_error)
                    break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    logger.exception("处理消息时发生错误: {}", e)
                    if conversation_id in manager.active_connections:
                        try:
                            await _send_ws_error(conversation_id, "处理失败,请重试")
                        except Exception as send_error:
                            logger.warning("发送错误消息失败: {}", send_error)
                    break

        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
        except Exception as e:
            logger.exception("WebSocket 端点发生错误: {}", e)
        finally:
            # Single cleanup point for every exit path (normal end, early
            # return, disconnect, error). Previously the except branches also
            # ran this, so cleanup happened twice.
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)
|