Two complementary changes to reduce conversation cold-start friction:

A. Returning-user re-greeting (backend)
- When the WebSocket reconnects to a non-empty conversation whose last_message_at is older than chat_re_greeting_idle_hours (default 6h), the agent emits a warm continuation message that references prior history instead of staying silent.
- Self-debouncing: the AI message updates last_message_at, so reconnects within the window do not re-trigger it.
- Skipped while profile collection is still pending.

D. Topic suggestion chips (backend + Expo)
- A new WS message type, topic_suggestions, carries 3-4 quick-start chips derived from the empty slots of the current memoir stage (deterministic, no extra LLM cost). It is sent alongside the opening, re-greeting, and resume messages.
- The Expo chat screen renders a horizontally scrollable chip row above the input bar. Tapping a chip sends the chip's text as a user message and clears the row; sending any text or voice message also clears the chips.
865 lines · 40 KiB · Python
"""
|
||
WebSocket 路由:实时对话通信
|
||
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
|
||
"""
|
||
|
||
import asyncio
|
||
import base64
|
||
from datetime import datetime, timezone
|
||
|
||
from fastapi import WebSocket, WebSocketDisconnect, status
|
||
from starlette.websockets import WebSocketState
|
||
|
||
from app.agents.chat.background_voice import infer_background_voice
|
||
from app.agents.chat.prompts_conversation import build_topic_chips
|
||
from app.agents.chat.prompts_profile import format_user_profile_context
|
||
from app.agents.stage_constants import STAGE_TO_ORDER
|
||
from app.agents.state_schema import (
|
||
interview_control_state,
|
||
narrative_coverage_state,
|
||
)
|
||
from app.core.config import settings
|
||
from app.core.db import AsyncSessionLocal
|
||
from app.core.dependencies import get_asr_provider
|
||
from app.core.logging import get_logger
|
||
from app.core.security import verify_token
|
||
from app.features.conversation.history_store import ConversationHistoryStore
|
||
from app.features.conversation.service import ConversationService
|
||
from app.features.conversation.ws.connection_manager import manager
|
||
from app.features.conversation.ws.message_types import MessageType
|
||
from app.features.conversation.ws.pipeline import (
|
||
_delayed_listening_feedback,
|
||
_voice_session_id_from_client_segment_id,
|
||
bump_tts_cancel_epoch,
|
||
chat_orchestrator,
|
||
cleanup_segment_states,
|
||
get_or_create_segment_state,
|
||
memoir_ingest_scheduler,
|
||
process_audio_segment,
|
||
process_conversation_segments,
|
||
process_user_message,
|
||
register_segment_task,
|
||
)
|
||
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
|
||
from app.features.conversation.ws.quota_guard import check_ws_quota
|
||
from app.features.memoir.state_service import get_or_create_state
|
||
from app.features.quota.service import QuotaService
|
||
from app.features.user.models import User
|
||
|
||
# Module-level logger from the project's logging wrapper (app.core.logging); the calls in
# this module use "{}" placeholders for their arguments.
logger = get_logger(__name__)
|
||
|
||
|
||
def _idle_hours_since(ts) -> float | None:
|
||
"""计算距 ts 的小时数;ts 为 None 或非 datetime 时返回 None。"""
|
||
if ts is None:
|
||
return None
|
||
if not isinstance(ts, datetime):
|
||
return None
|
||
if ts.tzinfo is None:
|
||
ts = ts.replace(tzinfo=timezone.utc)
|
||
delta = datetime.now(timezone.utc) - ts
|
||
return max(0.0, delta.total_seconds() / 3600.0)
|
||
|
||
|
||
# Pause between consecutive AI-only segments so the client renders them one at a time.
_AI_SEGMENT_INTERVAL_S = 0.5

# ensure_ws_connection() error codes that reject the socket, mapped to the text used both
# as the ERROR frame message and as the WebSocket close reason.
_WS_CONN_REJECT_REASONS: dict[str, str] = {
    "forbidden": "无权访问此对话",
    "deleted": "对话已删除",
}


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string for outbound frame timestamps."""
    return datetime.now(timezone.utc).isoformat()


def _error_payload(message: str, code: str | None = None) -> dict:
    """Build an ERROR frame; ``code`` (e.g. ``"QUOTA_EXCEEDED"``) is included only if given."""
    data = {"message": message}
    if code is not None:
        data["code"] = code
    return {"type": MessageType.ERROR, "data": data, "timestamp": _now_iso()}


async def _send_error(
    conversation_id: str, message: str, *, code: str | None = None
) -> None:
    """Send an ERROR frame to ``conversation_id`` through the connection manager."""
    await manager.send_message(conversation_id, _error_payload(message, code))


def _message_data(message: dict) -> dict:
    """Return an inbound frame's ``data`` object, or ``{}`` if missing or not an object.

    ``message.get("data", {})`` returns ``None`` for ``"data": null``; the chained
    ``.get`` then raised AttributeError and the generic handler closed the connection.
    """
    data = message.get("data")
    return data if isinstance(data, dict) else {}


def _safe_int(value: object, default: int = 0) -> int:
    """Return ``int(value)``, or ``default`` for None, non-numeric or non-finite input."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _is_disconnect_runtime_error(error_msg: str) -> bool:
    """Whether a RuntimeError text means the socket is closed or was never accepted.

    The substrings mirror the RuntimeError messages Starlette raises for a disconnected
    or not-yet-accepted WebSocket (re-check them when upgrading Starlette).
    """
    lowered = error_msg.lower()
    return (
        "disconnect" in lowered
        or 'Cannot call "receive"' in error_msg
        or ("accept" in lowered and "not connected" in lowered)
    )


def _profile_prompt_kwargs(user: User) -> dict[str, object]:
    """Profile-derived keyword arguments shared by the opening and re-greeting generators."""
    return {
        "user_profile_context": format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        ),
        "background_voice": infer_background_voice(user.occupation),
        "occupation": user.occupation or "",
        "profile_birth_year": user.birth_year,
        # Where the user grew up, falling back to the birthplace, then to "".
        "profile_era_place": user.grew_up_place or user.birth_place or "",
    }


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """WebSocket endpoint for a real-time conversation.

    Lifecycle:
      1. Authenticate the ``token`` query parameter: it must be an ``access`` token
         whose ``sub`` names an existing user. Otherwise the socket is closed with
         a policy-violation code (1008).
      2. Register the socket with the connection manager, send CONNECT, and verify
         the user may use this conversation (forbidden/deleted -> ERROR + close).
      3. Align ``conversation_stage`` with the memoir state and warm the Redis history.
      4. Speak first where useful:
         - empty conversation: profile-collection greeting, or an opening message
           followed by topic chips;
         - non-empty conversation idle past the threshold: a re-greeting followed
           by topic chips;
         - otherwise: topic chips only.
      5. Dispatch inbound frames until the client ends the conversation,
         disconnects, or an unexpected error occurs. Handled frame types: text,
         recording start, audio segment, whole audio message, transcribe-only,
         TTS cancel, end, and ping.

    Args:
        websocket: The WebSocket connection.
        conversation_id: ID of the conversation this socket is bound to.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return

    async with AsyncSessionLocal() as db:
        user = await db.get(User, user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )

        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)

        async def _quota_ok() -> bool:
            """Check the quota; on refusal send a QUOTA_EXCEEDED error and return False."""
            can_send, quota_msg = await check_ws_quota(
                quota_service, user_id, user.subscription_type
            )
            if not can_send:
                await _send_error(conversation_id, quota_msg, code="QUOTA_EXCEEDED")
            return bool(can_send)

        async def _notify_processing_failure() -> None:
            """Best-effort generic error to the client while it is still registered."""
            if conversation_id not in manager.active_connections:
                return
            try:
                await _send_error(conversation_id, "处理失败,请重试")
            except Exception as send_error:
                logger.warning("发送错误消息失败: {}", send_error)

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _now_iso(),
                },
            )

            conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
                conversation_id, user_id
            )
            reject_reason = _WS_CONN_REJECT_REASONS.get(ws_conn_err)
            if reject_reason is not None:
                try:
                    await _send_error(conversation_id, reject_reason)
                except Exception as send_error:
                    # Best effort only: the socket is closed below either way.
                    logger.debug("发送拒绝原因失败: {}", send_error)
                await websocket.close(
                    code=status.WS_1008_POLICY_VIOLATION, reason=reject_reason
                )
                return

            # Cold start: align conversation_stage with MemoirState.current_stage, but
            # never overwrite a conversation stage that is already further along
            # (larger STAGE_TO_ORDER), to avoid moving the conversation backwards.
            memoir_state = await get_or_create_state(user_id, db)
            memoir_stage = (memoir_state.current_stage or "").strip()
            conv_stage = (conversation.conversation_stage or "").strip()
            if memoir_stage and (
                not conv_stage
                or STAGE_TO_ORDER.get(memoir_stage, -1)
                >= STAGE_TO_ORDER.get(conv_stage, -1)
            ):
                conversation.conversation_stage = memoir_stage
            await db.commit()
            await db.refresh(conversation)

            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )

            async def _stream_ai_only_messages(
                texts: list[str], log_label: str
            ) -> bool:
                """Persist an AI-only turn, then send each text as an AGENT_RESPONSE frame.

                ``texts`` are the already-split segments (presumably split on [SPLIT]
                upstream). Returns True only if the turn was recorded and sent, so
                callers can avoid logging phantom greetings.
                """
                if not texts:
                    return False
                ai_msg_id = await ConversationHistoryStore(db).record_ai_only_turn(
                    conversation_id, texts
                )
                if not ai_msg_id:
                    return False
                total_n = len(texts)
                for i, text in enumerate(texts):
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.AGENT_RESPONSE,
                            "conversation_id": conversation_id,
                            "data": {
                                "text": text,
                                "index": i,
                                "total": total_n,
                                "assistant_message_id": ai_msg_id,
                            },
                            "timestamp": _now_iso(),
                        },
                    )
                    if i < total_n - 1:
                        await asyncio.sleep(_AI_SEGMENT_INTERVAL_S)
                logger.info(
                    "event=ws_auto_ai_sent label={} conversation_id={} segments={}",
                    log_label,
                    conversation_id,
                    total_n,
                )
                return True

            async def _maybe_send_topic_chips(reason: str) -> None:
                """Send quick-start topic chips built from the current stage's empty slots.

                Any failure is logged and swallowed; chips are an optional nicety.
                """
                if not settings.chat_topic_chips_enabled:
                    return
                try:
                    # Profile collection runs its own flow; chips would only add noise.
                    if get_missing_profile_fields(user):
                        return
                    stage = memoir_state.current_stage
                    narrative_state = narrative_coverage_state(memoir_state)
                    control_state = interview_control_state(memoir_state)
                    empty_slots = control_state.prompt_empty_slots_for_stage(
                        narrative_state, stage
                    )
                    chips = build_topic_chips(
                        stage,
                        empty_slots,
                        max_chips=settings.chat_topic_chips_max,
                    )
                    if not chips:
                        return
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.TOPIC_SUGGESTIONS,
                            "conversation_id": conversation_id,
                            "data": {
                                "reason": reason,
                                "stage": stage,
                                "suggestions": chips,
                            },
                            "timestamp": _now_iso(),
                        },
                    )
                    logger.info(
                        "event=ws_topic_chips_sent reason={} conversation_id={} "
                        "stage={} count={}",
                        reason,
                        conversation_id,
                        stage,
                        len(chips),
                    )
                except Exception as e:
                    logger.warning("发送话题 chips 失败: {}", e)

            if not history:
                missing_profile = get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await chat_orchestrator.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                        )
                        await _stream_ai_only_messages(
                            greetings, log_label="profile_greeting"
                        )
                    except Exception as e:
                        logger.exception("发送资料收集开场白失败: {}", e)
                else:
                    try:
                        opening_messages = (
                            await chat_orchestrator.generate_opening_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                **_profile_prompt_kwargs(user),
                            )
                        )
                        await _stream_ai_only_messages(
                            opening_messages, log_label="opening"
                        )
                        await _maybe_send_topic_chips(reason="opening")
                    except Exception as e:
                        logger.exception("发送空对话开场白失败: {}", e)
            else:
                # Non-empty history: re-greet only after the configured idle gap. The
                # recorded AI turn is expected to bump last_message_at, so reconnects
                # inside the window do not repeat the greeting.
                idle_hours = _idle_hours_since(conversation.last_message_at)
                threshold = float(settings.chat_re_greeting_idle_hours)
                should_re_greet = (
                    settings.chat_re_greeting_enabled
                    and not get_missing_profile_fields(user)
                    and idle_hours is not None
                    and idle_hours >= threshold
                )
                if should_re_greet:
                    try:
                        re_greetings = (
                            await chat_orchestrator.generate_re_greeting_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                idle_hours=idle_hours,
                                **_profile_prompt_kwargs(user),
                            )
                        )
                        if await _stream_ai_only_messages(
                            re_greetings, log_label="re_greeting"
                        ):
                            logger.info(
                                "event=ws_re_greeting_emitted conversation_id={} "
                                "idle_hours={:.2f} threshold={:.2f}",
                                conversation_id,
                                idle_hours,
                                threshold,
                            )
                        await _maybe_send_topic_chips(reason="re_greeting")
                    except Exception as e:
                        logger.exception("发送回访问候失败: {}", e)
                else:
                    # No re-greeting: still offer chips to lower the cold-start barrier.
                    await _maybe_send_topic_chips(reason="resume")

            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break

                    try:
                        message = await websocket.receive_json()
                    except ValueError as decode_error:
                        # json.JSONDecodeError is a ValueError: one malformed frame is
                        # skipped instead of tearing down the session.
                        logger.warning(
                            "WebSocket 收到无法解析的 JSON conversation_id={} error={}",
                            conversation_id,
                            decode_error,
                        )
                        continue
                    if not isinstance(message, dict):
                        logger.warning(
                            "WebSocket 收到非对象 JSON conversation_id={}",
                            conversation_id,
                        )
                        continue

                    msg_type = message.get("type")
                    data = _message_data(message)

                    if msg_type == MessageType.AUDIO_SEGMENT:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={} "
                            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
                            msg_type,
                            conversation_id,
                            data.get("segment_index"),
                            bool(data.get("is_last")),
                            _safe_int(data.get("duration")),
                            len(data.get("audio_base64") or ""),
                        )
                    elif msg_type is not None:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                    else:
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )

                    if msg_type == MessageType.TEXT:
                        text_message = data.get("text", "")
                        if not isinstance(text_message, str) or not text_message:
                            continue
                        if not await _quota_ok():
                            continue

                        segment = await conversation_service.create_user_segment(
                            conversation,
                            user_id,
                            text_message,
                        )
                        user_message_timestamp = conversation.last_message_at
                        await memoir_ingest_scheduler.queue_segment(
                            conversation.user_id,
                            segment.id,
                            text_char_count=len(text_message.strip()),
                        )
                        await process_user_message(
                            conversation_id=conversation_id,
                            user_message=text_message,
                            conversation=conversation,
                            segment=segment,
                            db=db,
                            user=user,
                            user_message_timestamp=segment.created_at
                            or user_message_timestamp,
                        )

                    elif msg_type == MessageType.RECORDING_STARTED:
                        raw_vs = data.get("voice_session_id")
                        if not raw_vs or not str(raw_vs).strip():
                            await _send_error(conversation_id, "缺少 voice_session_id")
                            continue
                        voice_session_id = str(raw_vs).strip()
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            pending = segment_state.listening_feedback_task
                            # Skip if a delayed feedback task is still pending, or
                            # feedback was already sent for this voice session.
                            if (
                                pending is not None and not pending.done()
                            ) or segment_state.listening_feedback_sent:
                                continue
                            segment_state.listening_feedback_task = asyncio.create_task(
                                _delayed_listening_feedback(
                                    conversation_id=conversation_id,
                                    voice_session_id=voice_session_id,
                                )
                            )

                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        audio_base64 = data.get("audio_base64", "")
                        segment_index_raw = data.get("segment_index")
                        resolved_vs = data.get(
                            "voice_session_id"
                        ) or _voice_session_id_from_client_segment_id(
                            data.get("client_segment_id")
                        )
                        if not resolved_vs or not str(resolved_vs).strip():
                            await _send_error(
                                conversation_id,
                                "缺少 voice_session_id 或有效的 client_segment_id",
                            )
                            continue
                        voice_session_id = str(resolved_vs).strip()
                        is_last = bool(data.get("is_last", False))
                        audio_duration = _safe_int(data.get("duration", 0))

                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue
                        if segment_index_raw is None:
                            await _send_error(conversation_id, "缺少 segment_index")
                            continue
                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError, OverflowError):
                            await _send_error(
                                conversation_id, "segment_index 必须为整数"
                            )
                            continue
                        if segment_index < 0:
                            await _send_error(
                                conversation_id, "segment_index 不能为负数"
                            )
                            continue
                        if not await _quota_ok():
                            continue

                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            # Drop retransmissions: already pending, already processed,
                            # or at/below consumed_index for this voice session.
                            already_seen = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not already_seen:
                                segment_state.pending_indices.add(segment_index)

                        if already_seen:
                            logger.debug(
                                "收到重复分段,跳过: conversation_id={} voice_session_id={} "
                                "segment_index={} audio_b64_len={} duration={}",
                                conversation_id,
                                voice_session_id,
                                segment_index,
                                len(audio_base64 or ""),
                                audio_duration,
                            )
                            continue

                        if is_last:
                            # The final segment supersedes the pending "listening" feedback.
                            async with segment_state.lock:
                                feedback_task = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                            if feedback_task is not None and not feedback_task.done():
                                feedback_task.cancel()

                        task = asyncio.create_task(
                            process_audio_segment(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                            )
                        )
                        register_segment_task(conversation_id, voice_session_id, task)

                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        audio_base64 = data.get("audio_base64", "")
                        audio_duration = data.get("duration", 0)
                        if not audio_base64:
                            continue
                        if not await _quota_ok():
                            continue

                        logger.debug(
                            "收到音频消息: conversation_id={} duration_s={}",
                            conversation_id,
                            audio_duration,
                        )
                        try:
                            asr = get_asr_provider()
                            audio_bytes = base64.b64decode(audio_base64)
                            asr_text = await asr.transcribe(audio_bytes, "m4a")
                            logger.debug(
                                "ASR 转写完成: conversation_id={} chars={}",
                                conversation_id,
                                len(asr_text or ""),
                            )
                            logger.debug(
                                "ASR 转写全文: conversation_id={} text={}",
                                conversation_id,
                                asr_text,
                            )

                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {
                                        "text": asr_text,
                                        "audio_duration": audio_duration,
                                    },
                                    "timestamp": _now_iso(),
                                },
                            )

                            ads = _safe_int(audio_duration)
                            # NOTE(review): the segment is persisted and queued for memoir
                            # ingest even when transcription failed (empty text or the
                            # "转写失败" placeholder) — confirm this is intended.
                            segment = await conversation_service.create_user_segment(
                                conversation,
                                user_id,
                                asr_text,
                                audio_url=f"audio:{audio_duration}s",
                                audio_duration_seconds=ads if ads > 0 else None,
                            )
                            user_message_timestamp = conversation.last_message_at
                            await memoir_ingest_scheduler.queue_segment(
                                conversation.user_id,
                                segment.id,
                                text_char_count=len((asr_text or "").strip()),
                            )

                            if asr_text and not asr_text.startswith("转写失败"):
                                await process_user_message(
                                    conversation_id=conversation_id,
                                    user_message=asr_text,
                                    conversation=conversation,
                                    segment=segment,
                                    db=db,
                                    user=user,
                                    user_message_timestamp=segment.created_at
                                    or user_message_timestamp,
                                )
                            else:
                                await _send_error(
                                    conversation_id,
                                    "语音转写失败,请重试或使用文字输入",
                                )
                        except Exception as e:
                            logger.exception("处理音频消息失败: {}", e)
                            await _send_error(
                                conversation_id,
                                "语音处理失败,请重试或使用文字输入",
                            )

                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        audio_base64 = data.get("audio_base64", "")
                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue
                        try:
                            asr = get_asr_provider()
                            audio_bytes = base64.b64decode(audio_base64)
                            asr_text = await asr.transcribe(audio_bytes, "m4a")
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {"text": asr_text or ""},
                                    "timestamp": _now_iso(),
                                },
                            )
                        except Exception as e:
                            logger.exception("仅转写失败: {}", e)
                            await _send_error(conversation_id, "语音转写失败,请重试")

                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)

                    elif msg_type == MessageType.END_CONVERSATION:
                        await conversation_service.end(conversation_id, user_id)
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": _now_iso(),
                            },
                        )
                        break

                    elif msg_type == MessageType.PING:
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": _now_iso(),
                            },
                        )

                    elif msg_type is not None:
                        logger.warning(
                            "WebSocket 未识别的消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )

                except RuntimeError as e:
                    error_msg = str(e)
                    if _is_disconnect_runtime_error(error_msg):
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                        break
                    logger.exception("处理消息时发生 RuntimeError: {}", e)
                    await _notify_processing_failure()
                    break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    # Unexpected failure: report it and end the session; the shared DB
                    # session may no longer be usable after an arbitrary error.
                    logger.exception("处理消息时发生错误: {}", e)
                    await _notify_processing_failure()
                    break

        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
        except Exception as e:
            logger.exception("WebSocket 端点发生错误: {}", e)
        finally:
            # Single cleanup point for every exit path. (Previously each except arm also
            # did this, so disconnect/cleanup ran twice.)
            # NOTE(review): the manager is keyed by conversation_id only; if the client
            # reconnects before this cleanup runs, this may drop the newer socket's
            # registration — confirm ConnectionManager.disconnect semantics.
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)
|