fix(conversation): 离屏不丢回复、列表预热 WS 与非阻塞进入聊天
- 后端:文本/转写后 AI 生成改为独立任务,避免断连取消整轮;按需 TTS 等与 WS 改动 - 前端:RealtimeSession 重绑 UI 时恢复流式 buffer;列表 onPressIn/挂载预热、已有会话立即 push - 同步会话相关类型、i18n、测试与 env/资源等累计改动 Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -28,11 +28,13 @@ from app.features.conversation.ws.pipeline import (
|
||||
chat_orchestrator,
|
||||
cleanup_segment_states,
|
||||
get_or_create_segment_state,
|
||||
handle_tts_request_on_demand,
|
||||
memoir_ingest_scheduler,
|
||||
process_audio_segment,
|
||||
process_conversation_segments,
|
||||
process_user_message,
|
||||
process_persisted_user_segment_response,
|
||||
register_segment_task,
|
||||
register_user_response_task,
|
||||
)
|
||||
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
|
||||
from app.features.conversation.ws.quota_guard import check_ws_quota
|
||||
@@ -276,7 +278,9 @@ async def websocket_endpoint(
|
||||
)
|
||||
|
||||
if msg_type == MessageType.TEXT:
|
||||
text_message = message.get("data", {}).get("text", "")
|
||||
data = message.get("data") or {}
|
||||
text_message = data.get("text", "")
|
||||
tts_this_turn = bool(data.get("tts_this_turn"))
|
||||
|
||||
if text_message:
|
||||
can_send, quota_msg = await check_ws_quota(
|
||||
@@ -303,23 +307,21 @@ async def websocket_endpoint(
|
||||
user_id,
|
||||
text_message,
|
||||
)
|
||||
user_message_timestamp = conversation.last_message_at
|
||||
await memoir_ingest_scheduler.queue_segment(
|
||||
conversation.user_id,
|
||||
segment.id,
|
||||
text_char_count=len(text_message.strip()),
|
||||
)
|
||||
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=text_message,
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
user=user,
|
||||
user_message_timestamp=segment.created_at
|
||||
or user_message_timestamp,
|
||||
task = asyncio.create_task(
|
||||
process_persisted_user_segment_response(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
segment_id=segment.id,
|
||||
tts_this_turn=tts_this_turn,
|
||||
)
|
||||
)
|
||||
register_user_response_task(conversation_id, task)
|
||||
|
||||
elif msg_type == MessageType.RECORDING_STARTED:
|
||||
data = message.get("data", {})
|
||||
@@ -486,6 +488,7 @@ async def websocket_endpoint(
|
||||
audio_base64=audio_base64,
|
||||
audio_duration=audio_duration,
|
||||
is_last=is_last,
|
||||
tts_this_turn=bool(data.get("tts_this_turn")),
|
||||
)
|
||||
)
|
||||
register_segment_task(conversation_id, voice_session_id, task)
|
||||
@@ -494,6 +497,7 @@ async def websocket_endpoint(
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
audio_duration = data.get("duration", 0)
|
||||
tts_this_turn = bool(data.get("tts_this_turn"))
|
||||
|
||||
if audio_base64:
|
||||
can_send, quota_msg = await check_ws_quota(
|
||||
@@ -564,7 +568,6 @@ async def websocket_endpoint(
|
||||
audio_duration_seconds=ads if ads > 0 else None,
|
||||
)
|
||||
)
|
||||
user_message_timestamp = conversation.last_message_at
|
||||
await memoir_ingest_scheduler.queue_segment(
|
||||
conversation.user_id,
|
||||
segment.id,
|
||||
@@ -572,16 +575,15 @@ async def websocket_endpoint(
|
||||
)
|
||||
|
||||
if asr_text and not asr_text.startswith("转写失败"):
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=asr_text,
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
user=user,
|
||||
user_message_timestamp=segment.created_at
|
||||
or user_message_timestamp,
|
||||
task = asyncio.create_task(
|
||||
process_persisted_user_segment_response(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
segment_id=segment.id,
|
||||
tts_this_turn=tts_this_turn,
|
||||
)
|
||||
)
|
||||
register_user_response_task(conversation_id, task)
|
||||
else:
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
@@ -651,6 +653,51 @@ async def websocket_endpoint(
|
||||
elif msg_type == MessageType.TTS_CANCEL:
|
||||
bump_tts_cancel_epoch(conversation_id)
|
||||
|
||||
elif msg_type == MessageType.TTS_REQUEST:
|
||||
data = message.get("data") or {}
|
||||
aid = data.get("assistant_message_id") or data.get(
|
||||
"assistantMessageId"
|
||||
)
|
||||
if not aid or not str(aid).strip():
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少助手消息 id"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
try:
|
||||
seg_idx = int(
|
||||
data.get("segment_index", data.get("segmentIndex", 0))
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
seg_idx = 0
|
||||
st = data.get("segment_text") or data.get("segmentText")
|
||||
st_val: str | None
|
||||
if st is None:
|
||||
st_val = None
|
||||
else:
|
||||
st_val = str(st).strip() or None
|
||||
ok, err_msg = await handle_tts_request_on_demand(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
assistant_message_id=str(aid).strip(),
|
||||
segment_index=seg_idx,
|
||||
segment_text=st_val,
|
||||
db=db,
|
||||
)
|
||||
if not ok:
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": err_msg or "朗读请求失败"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.END_CONVERSATION:
|
||||
await conversation_service.end(conversation_id, user_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user