Files
life-echo/api/app/features/conversation/ws/router.py
Kevin 07979bfb09 feat(api): use Tencent ASR flash with 16k_zh_large and dev transcript logs
Replace CreateRecTask polling with recording-file flash API, add TENCENT_APP_ID,
remove server-side pydub slicing, and log ASR recognition text at INFO in development.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 11:28:22 +08:00

879 lines
41 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WebSocket 路由:实时对话通信
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
"""
import asyncio
import base64
from datetime import datetime, timezone
from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState
from app.agents.chat.background_voice import infer_background_voice
from app.agents.chat.prompts_profile import format_user_profile_context
from app.core.agent_logging import log_asr_transcript_result
from app.core.config import settings
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
_delayed_listening_feedback,
_voice_session_id_from_client_segment_id,
bump_tts_cancel_epoch,
chat_orchestrator,
cleanup_segment_states,
get_or_create_segment_state,
handle_tts_request_on_demand,
process_audio_segment,
process_conversation_segments,
process_persisted_user_segment_response,
register_segment_task,
register_user_response_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.conversation.ws.topic_chips_push import maybe_send_topic_chips_ws
from app.features.memoir.service import MemoirService
from app.features.quota.service import QuotaService
from app.features.user.service import UserService
from app.features.conversation.constants import chat
# Module-level logger obtained from the project logging helper for this module's name.
logger = get_logger(__name__)
def _idle_hours_since(ts) -> float | None:
"""计算距 ts 的小时数ts 为 None 或非 datetime 时返回 None。"""
if ts is None:
return None
if not isinstance(ts, datetime):
return None
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
delta = datetime.now(timezone.utc) - ts
return max(0.0, delta.total_seconds() / 3600.0)
def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string for message timestamps."""
    return datetime.now(timezone.utc).isoformat()


def _coerce_int(value, default: int = 0) -> int:
    """Convert a client-supplied value to ``int``; return ``default`` if impossible."""
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


async def _send_error(conversation_id: str, message, code: str | None = None) -> None:
    """Send an ERROR frame with ``message`` (and optional ``code``) to the client."""
    data = {"message": message}
    if code is not None:
        data["code"] = code
    await manager.send_message(
        conversation_id,
        {
            "type": MessageType.ERROR,
            "data": data,
            "timestamp": _utc_now_iso(),
        },
    )


async def _reject_connection(
    websocket: WebSocket, conversation_id: str, reason: str
) -> None:
    """Tell the client why it is rejected (best effort), then close with code 1008."""
    try:
        await _send_error(conversation_id, reason)
    except Exception as send_error:
        # Best effort only: the close below must still happen.
        logger.debug(
            "关闭前发送错误消息失败: conversation_id={} error={}",
            conversation_id,
            send_error,
        )
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=reason)


async def _quota_allows(
    conversation_id: str, quota_service, user_id, subscription_type
) -> bool:
    """Run the quota check; on refusal send a QUOTA_EXCEEDED error and return False."""
    can_send, quota_msg = await check_ws_quota(
        quota_service, user_id, subscription_type
    )
    if can_send:
        return True
    await _send_error(conversation_id, quota_msg, code="QUOTA_EXCEEDED")
    return False


async def _report_processing_failure(conversation_id: str) -> bool:
    """Send the generic "processing failed" error if the connection is registered.

    Returns:
        ``False`` only when sending the error frame itself failed, ``True``
        otherwise (including when the connection is no longer registered).
    """
    if conversation_id not in manager.active_connections:
        return True
    try:
        await _send_error(conversation_id, "处理失败,请重试")
    except Exception as send_error:
        logger.warning("发送错误消息失败: {}", send_error)
        return False
    return True


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """WebSocket endpoint that drives one real-time conversation.

    Flow:
        1. Authenticate via the ``token`` query parameter and load the user;
           any failure closes the socket with a policy-violation code.
        2. Register the socket with the connection manager, send a CONNECT
           frame and validate access to the conversation.
        3. Emit an automatic AI message: a profile-collection greeting or an
           opening message when there is no history, otherwise a re-greeting
           when the conversation has been idle long enough (or just topic
           chips when it has not).
        4. Dispatch incoming JSON messages by ``type`` until the client
           disconnects or ends the conversation.
        5. Always unregister the connection and clean up segment state.

    Args:
        websocket: The WebSocket connection.
        conversation_id: ID of the conversation served by this connection.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return

    # The DB session stays open for the whole connection: `db` is still used by
    # the TTS_REQUEST and END_CONVERSATION branches of the receive loop.
    async with AsyncSessionLocal() as db:
        user_service = UserService(db)
        user = await user_service.get_by_id(user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )
        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)
        memoir_service = MemoirService(db=db)

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _utc_now_iso(),
                },
            )

            conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
                conversation_id, user_id
            )
            if ws_conn_err == "forbidden":
                await _reject_connection(websocket, conversation_id, "无权访问此对话")
                return
            if ws_conn_err == "deleted":
                await _reject_connection(websocket, conversation_id, "对话已删除")
                return

            # On cold start, align conversation_stage with MemoirState.current_stage.
            # If the conversation row already holds a later life stage (larger
            # STAGE_TO_ORDER), it is not overwritten, to avoid regressing.
            memoir_state = await memoir_service.get_or_create_memoir_state(user_id)
            await conversation_service.align_conversation_stage_from_memoir(
                conversation, memoir_state.current_stage or ""
            )
            await db.refresh(conversation)
            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )

            preferred = str(getattr(user, "language_preference", "zh") or "zh")
            user_language = "en" if preferred.lower() == "en" else "zh"

            async def _stream_ai_only_messages(
                texts: list[str], log_label: str
            ) -> None:
                """Persist a group of AI messages and send them one segment at a time."""
                if not texts:
                    return
                ai_msg_id = await conversation_service.record_ai_only_turn(
                    conversation_id, texts
                )
                if not ai_msg_id:
                    return
                total_n = len(texts)
                for i, text in enumerate(texts):
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.AGENT_RESPONSE,
                            "conversation_id": conversation_id,
                            "data": {
                                "text": text,
                                "index": i,
                                "total": total_n,
                                "assistant_message_id": ai_msg_id,
                            },
                            "timestamp": _utc_now_iso(),
                        },
                    )
                    # Short pause between segments, but not after the last one.
                    if i < total_n - 1:
                        await asyncio.sleep(0.5)
                logger.info(
                    "event=ws_auto_ai_sent label={} conversation_id={} segments={}",
                    log_label,
                    conversation_id,
                    total_n,
                )

            async def _maybe_send_topic_chips(reason: str) -> None:
                """Delegate to the topic-chips push helper with this connection's context."""
                await maybe_send_topic_chips_ws(
                    conversation_id,
                    user=user,
                    memoir_state=memoir_state,
                    reason=reason,
                    language=user_language,
                )

            if not history:
                missing_profile = get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await chat_orchestrator.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                            language=user_language,
                        )
                        await _stream_ai_only_messages(
                            greetings, log_label="profile_greeting"
                        )
                    except Exception as e:
                        logger.exception("发送资料收集开场白失败: {}", e)
                else:
                    try:
                        user_profile_context = format_user_profile_context(
                            birth_year=user.birth_year,
                            birth_place=user.birth_place,
                            grew_up_place=user.grew_up_place,
                            occupation=user.occupation,
                            language=user_language,
                        )
                        era_place = user.grew_up_place or user.birth_place or ""
                        opening_messages = (
                            await chat_orchestrator.generate_opening_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                user_profile_context=user_profile_context,
                                background_voice=infer_background_voice(
                                    user.occupation
                                ),
                                occupation=user.occupation or "",
                                profile_birth_year=user.birth_year,
                                profile_era_place=era_place,
                                language=user_language,
                            )
                        )
                        await _stream_ai_only_messages(
                            opening_messages, log_label="opening"
                        )
                        await _maybe_send_topic_chips(reason="opening")
                    except Exception as e:
                        logger.exception("发送空对话开场白失败: {}", e)
            else:
                # History exists: decide whether a re-greeting is due, i.e. whether
                # the time since the last message exceeds the configured threshold.
                idle_hours = _idle_hours_since(conversation.last_message_at)
                threshold = float(chat.re_greeting_idle_hours)
                if (
                    chat.re_greeting_enabled
                    and not get_missing_profile_fields(user)
                    and idle_hours is not None
                    and idle_hours >= threshold
                ):
                    try:
                        user_profile_context = format_user_profile_context(
                            birth_year=user.birth_year,
                            birth_place=user.birth_place,
                            grew_up_place=user.grew_up_place,
                            occupation=user.occupation,
                            language=user_language,
                        )
                        era_place = user.grew_up_place or user.birth_place or ""
                        re_greetings = (
                            await chat_orchestrator.generate_re_greeting_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                idle_hours=idle_hours,
                                user_profile_context=user_profile_context,
                                background_voice=infer_background_voice(
                                    user.occupation
                                ),
                                occupation=user.occupation or "",
                                profile_birth_year=user.birth_year,
                                profile_era_place=era_place,
                                language=user_language,
                            )
                        )
                        await _stream_ai_only_messages(
                            re_greetings, log_label="re_greeting"
                        )
                        logger.info(
                            "event=ws_re_greeting_emitted conversation_id={} "
                            "idle_hours={:.2f} threshold={:.2f}",
                            conversation_id,
                            idle_hours,
                            threshold,
                        )
                        await _maybe_send_topic_chips(reason="re_greeting")
                    except Exception as e:
                        logger.exception("发送回访问候失败: {}", e)
                else:
                    # No re-greeting this time; chips may still be pushed to lower
                    # the barrier to resuming the conversation.
                    await _maybe_send_topic_chips(reason="resume")

            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break

                    message = await websocket.receive_json()
                    msg_type = message.get("type")

                    if msg_type == MessageType.AUDIO_SEGMENT:
                        _d = message.get("data") or {}
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={} "
                            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
                            msg_type,
                            conversation_id,
                            _d.get("segment_index"),
                            bool(_d.get("is_last")),
                            # Logging must not fail on a malformed duration.
                            _coerce_int(_d.get("duration")),
                            len(_d.get("audio_base64") or ""),
                        )
                    elif msg_type is not None:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                    else:
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )

                    if msg_type == MessageType.TEXT:
                        data = message.get("data") or {}
                        text_message = data.get("text", "")
                        tts_this_turn = bool(data.get("tts_this_turn"))
                        if text_message:
                            if not await _quota_allows(
                                conversation_id,
                                quota_service,
                                user_id,
                                user.subscription_type,
                            ):
                                continue
                            segment = await conversation_service.create_user_segment(
                                conversation,
                                user_id,
                                text_message,
                            )
                            task = asyncio.create_task(
                                process_persisted_user_segment_response(
                                    conversation_id=conversation_id,
                                    user_id=user_id,
                                    segment_id=segment.id,
                                    tts_this_turn=tts_this_turn,
                                )
                            )
                            register_user_response_task(conversation_id, task)

                    elif msg_type == MessageType.RECORDING_STARTED:
                        # `or {}` also covers an explicit `"data": null`.
                        data = message.get("data") or {}
                        raw_vs = data.get("voice_session_id")
                        if not raw_vs or not str(raw_vs).strip():
                            await _send_error(conversation_id, "缺少 voice_session_id")
                            continue
                        voice_session_id = str(raw_vs).strip()
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            # Schedule the delayed feedback at most once per session:
                            # skip if one is still pending or was already sent.
                            if (
                                segment_state.listening_feedback_task is not None
                                and not segment_state.listening_feedback_task.done()
                            ):
                                continue
                            if segment_state.listening_feedback_sent:
                                continue
                            delayed_task = asyncio.create_task(
                                _delayed_listening_feedback(
                                    conversation_id=conversation_id,
                                    voice_session_id=voice_session_id,
                                )
                            )
                            segment_state.listening_feedback_task = delayed_task

                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        data = message.get("data") or {}
                        audio_base64 = data.get("audio_base64", "")
                        segment_index_raw = data.get("segment_index")
                        resolved_vs = data.get("voice_session_id") or (
                            _voice_session_id_from_client_segment_id(
                                data.get("client_segment_id")
                            )
                        )
                        if not resolved_vs or not str(resolved_vs).strip():
                            await _send_error(
                                conversation_id,
                                "缺少 voice_session_id 或有效的 client_segment_id",
                            )
                            continue
                        voice_session_id = str(resolved_vs).strip()
                        is_last = bool(data.get("is_last", False))
                        # A malformed duration is treated as 0 instead of raising.
                        audio_duration = _coerce_int(data.get("duration", 0))
                        tts_this_turn_segment = bool(data.get("tts_this_turn"))

                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue
                        if segment_index_raw is None:
                            await _send_error(conversation_id, "缺少 segment_index")
                            continue
                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError):
                            await _send_error(
                                conversation_id, "segment_index 必须为整数"
                            )
                            continue
                        if segment_index < 0:
                            await _send_error(
                                conversation_id, "segment_index 不能为负数"
                            )
                            continue

                        if not await _quota_allows(
                            conversation_id,
                            quota_service,
                            user_id,
                            user.subscription_type,
                        ):
                            continue

                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        should_process = False
                        async with segment_state.lock:
                            # De-duplicate: an index that is pending, processed or
                            # already consumed is not processed again.
                            already_seen = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not already_seen:
                                segment_state.pending_indices.add(segment_index)
                                should_process = True
                        if not should_process:
                            logger.debug(
                                "收到重复分段,跳过: conversation_id={} voice_session_id={} "
                                "segment_index={} audio_b64_len={} duration={}",
                                conversation_id,
                                voice_session_id,
                                segment_index,
                                len(audio_base64 or ""),
                                audio_duration,
                            )
                            continue

                        if is_last:
                            # Final segment: detach and cancel any pending
                            # listening-feedback task for this voice session.
                            async with segment_state.lock:
                                t = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                            if t is not None and not t.done():
                                t.cancel()

                        task = asyncio.create_task(
                            process_audio_segment(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                                tts_this_turn=tts_this_turn_segment,
                            )
                        )
                        register_segment_task(conversation_id, voice_session_id, task)

                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        data = message.get("data") or {}
                        audio_base64 = data.get("audio_base64", "")
                        audio_duration = data.get("duration", 0)
                        tts_this_turn = bool(data.get("tts_this_turn"))
                        if audio_base64:
                            if not await _quota_allows(
                                conversation_id,
                                quota_service,
                                user_id,
                                user.subscription_type,
                            ):
                                continue
                            logger.debug(
                                "收到音频消息: conversation_id={} duration_s={}",
                                conversation_id,
                                audio_duration,
                            )
                            try:
                                asr = get_asr_provider()
                                audio_bytes = base64.b64decode(audio_base64)
                                asr_text = await asr.transcribe(audio_bytes, "m4a")
                                log_asr_transcript_result(
                                    logger,
                                    text=asr_text or "",
                                    conversation_id=conversation_id,
                                    duration_s=audio_duration,
                                    source="audio_message",
                                )
                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.TRANSCRIPT,
                                        "conversation_id": conversation_id,
                                        "data": {
                                            "text": asr_text,
                                            "audio_duration": audio_duration,
                                        },
                                        "timestamp": _utc_now_iso(),
                                    },
                                )
                                ads = _coerce_int(audio_duration)
                                segment = (
                                    await conversation_service.create_user_segment(
                                        conversation,
                                        user_id,
                                        asr_text,
                                        audio_url=f"audio:{audio_duration}s",
                                        audio_duration_seconds=ads if ads > 0 else None,
                                    )
                                )
                                if asr_text and not asr_text.startswith("转写失败"):
                                    task = asyncio.create_task(
                                        process_persisted_user_segment_response(
                                            conversation_id=conversation_id,
                                            user_id=user_id,
                                            segment_id=segment.id,
                                            tts_this_turn=tts_this_turn,
                                        )
                                    )
                                    register_user_response_task(conversation_id, task)
                                else:
                                    await _send_error(
                                        conversation_id,
                                        "语音转写失败,请重试或使用文字输入",
                                    )
                            except Exception as e:
                                logger.exception("处理音频消息失败: {}", e)
                                await _send_error(
                                    conversation_id,
                                    "语音处理失败,请重试或使用文字输入",
                                )

                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        data = message.get("data") or {}
                        audio_base64 = data.get("audio_base64", "")
                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue
                        try:
                            asr = get_asr_provider()
                            audio_bytes = base64.b64decode(audio_base64)
                            asr_text = await asr.transcribe(audio_bytes, "m4a")
                            log_asr_transcript_result(
                                logger,
                                text=asr_text or "",
                                conversation_id=conversation_id,
                                source="transcribe_only",
                            )
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {"text": asr_text or ""},
                                    "timestamp": _utc_now_iso(),
                                },
                            )
                        except Exception as e:
                            logger.exception("仅转写失败: {}", e)
                            await _send_error(conversation_id, "语音转写失败,请重试")

                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)

                    elif msg_type == MessageType.TTS_REQUEST:
                        data = message.get("data") or {}
                        # Both snake_case and camelCase keys are accepted.
                        aid = data.get("assistant_message_id") or data.get(
                            "assistantMessageId"
                        )
                        if not aid or not str(aid).strip():
                            logger.warning(
                                "ws.TTS_REQUEST 缺少 assistant_message_id "
                                "conversation_id={} user_id={}",
                                conversation_id,
                                user_id,
                            )
                            await _send_error(conversation_id, "缺少助手消息 id")
                            continue
                        try:
                            seg_idx = int(
                                data.get("segment_index", data.get("segmentIndex", 0))
                            )
                        except (TypeError, ValueError):
                            seg_idx = 0
                        st = data.get("segment_text") or data.get("segmentText")
                        st_val: str | None = (
                            None if st is None else (str(st).strip() or None)
                        )
                        ok, err_msg = await handle_tts_request_on_demand(
                            conversation_id=conversation_id,
                            user_id=user_id,
                            assistant_message_id=str(aid).strip(),
                            segment_index=seg_idx,
                            segment_text=st_val,
                            db=db,
                        )
                        if not ok:
                            await _send_error(
                                conversation_id, err_msg or "朗读请求失败"
                            )

                    elif msg_type == MessageType.END_CONVERSATION:
                        await conversation_service.end(conversation_id, user_id)
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": _utc_now_iso(),
                            },
                        )
                        break

                    elif msg_type == MessageType.PING:
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": _utc_now_iso(),
                            },
                        )

                    elif msg_type is not None:
                        logger.warning(
                            "WebSocket 未识别的消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )

                except RuntimeError as e:
                    error_msg = str(e)
                    lowered = error_msg.lower()
                    # Explicit parentheses: `and` binds tighter than `or`, so the
                    # last alternative requires both substrings.
                    connection_gone = (
                        "disconnect" in lowered
                        or 'Cannot call "receive"' in error_msg
                        or ("accept" in lowered and "not connected" in lowered)
                    )
                    if connection_gone:
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                        break
                    logger.exception("处理消息时发生 RuntimeError: {}", e)
                    # NOTE(review): the reviewed copy of this file had lost its
                    # indentation; the loop is assumed to end only when the error
                    # frame cannot be delivered — confirm against the repository.
                    if not await _report_processing_failure(conversation_id):
                        break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    logger.exception("处理消息时发生错误: {}", e)
                    # NOTE(review): same assumption as the RuntimeError handler —
                    # keep serving unless the error frame itself cannot be sent.
                    if not await _report_processing_failure(conversation_id):
                        break

        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
        except Exception as e:
            logger.exception("WebSocket 端点发生错误: {}", e)
        finally:
            # Cleanup runs exactly once for every exit path; the segment state is
            # released even if unregistering the connection raises.
            try:
                await manager.disconnect(conversation_id)
            finally:
                cleanup_segment_states(conversation_id)