fix: 用户开始录音5s后ai反馈“我在认真听”

This commit is contained in:
yangshilin
2026-03-16 11:24:40 +08:00
parent 981920784f
commit 2070a03d35
9 changed files with 128 additions and 11 deletions

View File

@@ -32,6 +32,7 @@ LEGACY_VOICE_SESSION_ID = "legacy"
class MessageType(str, Enum): class MessageType(str, Enum):
"""WebSocket 消息类型""" """WebSocket 消息类型"""
CONNECT = "connect" CONNECT = "connect"
RECORDING_STARTED = "recording_started" # 客户端开始录音,用于服务端 5s 后发「我在认真听」
AUDIO_CHUNK = "audio_chunk" AUDIO_CHUNK = "audio_chunk"
AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传) AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传)
AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音) AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音)
@@ -146,6 +147,9 @@ class SegmentStreamState:
buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict) buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
consumed_index: int = -1 consumed_index: int = -1
active_tasks: Set[asyncio.Task] = field(default_factory=set) active_tasks: Set[asyncio.Task] = field(default_factory=set)
# 录音开始约 5s 后只发一次「我在认真听」;若用户提前结束录音则取消待发
listening_feedback_sent: bool = False
listening_feedback_task: Optional[asyncio.Task] = None
def _utc_now() -> datetime: def _utc_now() -> datetime:
@@ -257,17 +261,21 @@ async def _get_persisted_contiguous_segment_index(
return contiguous_index return contiguous_index
LISTENING_FEEDBACK_DELAY_SEC = 5.0
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
async def _send_segment_transition_feedback( async def _send_segment_transition_feedback(
conversation_id: str, conversation_id: str,
segment_index: int, segment_index: int,
manager: ConnectionManager, manager: ConnectionManager,
) -> None: ) -> None:
"""ASR 处理中先给陪伴式过渡反馈,避免用户感知卡住""" """发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)"""
await manager.send_message(conversation_id, { await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE, "type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id, "conversation_id": conversation_id,
"data": { "data": {
"text": "我在认真听,你继续说,我会边听边整理重点。", "text": LISTENING_FEEDBACK_TEXT,
"transition": True, "transition": True,
"segment_index": segment_index, "segment_index": segment_index,
}, },
@@ -275,6 +283,22 @@ async def _send_segment_transition_feedback(
}) })
async def _delayed_listening_feedback(
conversation_id: str,
voice_session_id: str,
manager: ConnectionManager,
) -> None:
"""录音开始后延迟 5 秒发送一次「我在认真听」,本会话内只发一次;若用户已结束录音则不再发送。"""
await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
state = manager.get_or_create_segment_state(conversation_id, voice_session_id)
async with state.lock:
if state.listening_feedback_sent:
return
state.listening_feedback_sent = True
state.listening_feedback_task = None
await _send_segment_transition_feedback(conversation_id, 0, manager)
async def _process_audio_segment_async( async def _process_audio_segment_async(
conversation_id: str, conversation_id: str,
user_id: str, user_id: str,
@@ -600,6 +624,28 @@ async def websocket_endpoint(
user_message_timestamp=segment.created_at or user_message_timestamp, user_message_timestamp=segment.created_at or user_message_timestamp,
) )
elif msg_type == MessageType.RECORDING_STARTED:
# 用户点击开始录音:启动 5s 定时器,到时发一次「我在认真听」
data = message.get("data", {})
voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
segment_state = manager.get_or_create_segment_state(
conversation_id,
voice_session_id,
)
async with segment_state.lock:
if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done():
continue # 本会话已有待发任务,不重复
if segment_state.listening_feedback_sent:
continue
delayed_task = asyncio.create_task(
_delayed_listening_feedback(
conversation_id=conversation_id,
voice_session_id=voice_session_id,
manager=manager,
)
)
segment_state.listening_feedback_task = delayed_task
elif msg_type == MessageType.AUDIO_SEGMENT: elif msg_type == MessageType.AUDIO_SEGMENT:
# 处理分段语音消息(长语音持续上传) # 处理分段语音消息(长语音持续上传)
data = message.get("data", {}) data = message.get("data", {})
@@ -680,12 +726,13 @@ async def websocket_endpoint(
) )
continue continue
# 先发过渡反馈,减少“等待空白”体感 # 若本段是用户结束录音的最后一段,取消尚未发出的「我在认真听」,避免结束后再说
await _send_segment_transition_feedback( if is_last:
conversation_id=conversation_id, async with segment_state.lock:
segment_index=segment_index, t = segment_state.listening_feedback_task
manager=manager, segment_state.listening_feedback_task = None
) if t is not None and not t.done():
t.cancel()
task = asyncio.create_task( task = asyncio.create_task(
_process_audio_segment_async( _process_audio_segment_async(

View File

@@ -752,13 +752,56 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
await ws_router.websocket_endpoint(fake_websocket, "conv-1") await ws_router.websocket_endpoint(fake_websocket, "conv-1")
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
# 当前逻辑:仅首个分段且 5s 后发一次「我在认真听」;若本段 is_last 则取消,故此处应为 0
transition_msgs = [ transition_msgs = [
item["message"] item["message"]
for item in fake_manager.sent_messages for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
and item["message"].get("data", {}).get("transition") is True and item["message"].get("data", {}).get("transition") is True
] ]
self.assertGreaterEqual(len(transition_msgs), 1) self.assertEqual(len(transition_msgs), 0)
async def test_recording_started_sends_listening_feedback_after_delay(self):
"""客户端发送 recording_started 后,延迟 5s 发一次「我在认真听」。"""
user = _make_user()
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
fake_manager = _FakeManager()
fake_websocket = _FakeWebSocket(
messages=[
{"type": "recording_started", "data": {"voice_session_id": "session-1"}},
WebSocketDisconnect(),
]
)
with ExitStack() as stack:
stack.enter_context(
patch.object(
ws_router,
"verify_token",
return_value={"type": "access", "sub": user.id},
)
)
stack.enter_context(
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(_redis_empty_history_patch())
stack.enter_context(
patch("routers.websocket.LISTENING_FEEDBACK_DELAY_SEC", 0.05)
)
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
await asyncio.sleep(0.12)
transition_msgs = [
item["message"]
for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
and item["message"].get("data", {}).get("transition") is True
]
self.assertEqual(len(transition_msgs), 1)
self.assertIn("我在认真听", transition_msgs[0]["data"].get("text", ""))
async def test_audio_segment_last_segment_does_not_emit_terminal_transition(self): async def test_audio_segment_last_segment_does_not_emit_terminal_transition(self):
user = _make_user() user = _make_user()
@@ -814,14 +857,14 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
await ws_router.websocket_endpoint(fake_websocket, "conv-1") await ws_router.websocket_endpoint(fake_websocket, "conv-1")
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
# 仅一段且 is_last延迟任务被取消不应发出 transition
transition_msgs = [ transition_msgs = [
item["message"] item["message"]
for item in fake_manager.sent_messages for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
and item["message"].get("data", {}).get("transition") is True and item["message"].get("data", {}).get("transition") is True
] ]
self.assertEqual(len(transition_msgs), 1) self.assertEqual(len(transition_msgs), 0)
self.assertIsNone(transition_msgs[0]["data"].get("is_last"))
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self): async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self):
user = _make_user() user = _make_user()

View File

@@ -62,6 +62,10 @@ class ConversationRealtimeAdapter(
webSocketClient.sendTextMessage(text, conversationId) webSocketClient.sendTextMessage(text, conversationId)
} }
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) {
webSocketClient.sendRecordingStarted(conversationId, voiceSessionId)
}
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) { override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) {
webSocketClient.sendAudioChunk(chunk, conversationId) webSocketClient.sendAudioChunk(chunk, conversationId)
} }

View File

@@ -24,6 +24,7 @@ interface ConversationRealtimePort {
fun isConnected(): Boolean fun isConnected(): Boolean
suspend fun sendText(conversationId: String, text: String) suspend fun sendText(conversationId: String, text: String)
suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String)
suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String)
suspend fun sendAudioSegment(request: AudioSegmentRequest) suspend fun sendAudioSegment(request: AudioSegmentRequest)
suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int)

View File

@@ -242,6 +242,17 @@ open class WebSocketClient(
data = buildJsonObject { put("text", JsonPrimitive(text)) } data = buildJsonObject { put("text", JsonPrimitive(text)) }
)) ))
} }
suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) {
if (!_isConnected) return
sendMessage(WebSocketMessage(
type = MessageType.recording_started,
conversation_id = conversationId,
data = buildJsonObject {
put("voice_session_id", JsonPrimitive(voiceSessionId))
}
))
}
suspend fun sendAudioSegment( suspend fun sendAudioSegment(
audioBytes: ByteArray, audioBytes: ByteArray,

View File

@@ -8,6 +8,7 @@ import kotlinx.serialization.json.buildJsonObject
@Serializable @Serializable
enum class MessageType { enum class MessageType {
connect, connect,
recording_started, // 用户点击开始录音,服务端据此 5s 后发「我在认真听」
audio_chunk, audio_chunk,
audio_segment, // 分段语音(长语音边录边传) audio_segment, // 分段语音(长语音边录边传)
audio_message, // 完整音频消息(类似微信语音) audio_message, // 完整音频消息(类似微信语音)

View File

@@ -357,6 +357,13 @@ class CreateMemoryViewModel(
} }
} }
Log.d(TAG, "录音已启动: ${result.session.filePath}") Log.d(TAG, "录音已启动: ${result.session.filePath}")
// 通知服务端「开始录音」,服务端 5s 后发「我在认真听」
viewModelScope.launch {
val cid = conversationId.value
if (cid != null && cid != "new" && ensureConnected()) {
conversationRealtime.sendRecordingStarted(cid, result.session.voiceSessionId)
}
}
} }
is RecordingStartResult.AlreadyRecording -> { is RecordingStartResult.AlreadyRecording -> {
Log.w(TAG, "录音已在进行中") Log.w(TAG, "录音已在进行中")

View File

@@ -661,6 +661,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
override suspend fun disconnect() = Unit override suspend fun disconnect() = Unit
override fun isConnected(): Boolean = connected override fun isConnected(): Boolean = connected
override suspend fun sendText(conversationId: String, text: String) = Unit override suspend fun sendText(conversationId: String, text: String) = Unit
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
override suspend fun sendAudioSegment(request: AudioSegmentRequest) { override suspend fun sendAudioSegment(request: AudioSegmentRequest) {
sendAudioSegmentFailure?.let { throw it } sendAudioSegmentFailure?.let { throw it }
@@ -705,6 +706,7 @@ class CreateMemoryViewModelRecordingCoordinatorTest {
override suspend fun sendText(conversationId: String, text: String) { override suspend fun sendText(conversationId: String, text: String) {
sentTexts += text sentTexts += text
} }
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit

View File

@@ -290,6 +290,7 @@ class CreateMemoryViewModelWarmupTest {
override fun isConnected(): Boolean = connected override fun isConnected(): Boolean = connected
override suspend fun sendText(conversationId: String, text: String) = Unit override suspend fun sendText(conversationId: String, text: String) = Unit
override suspend fun sendRecordingStarted(conversationId: String, voiceSessionId: String) = Unit
override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit override suspend fun sendAudioChunk(chunk: ByteArray, conversationId: String) = Unit
override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit override suspend fun sendAudioSegment(request: AudioSegmentRequest) = Unit
override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit override suspend fun sendAudioMessage(audioBytes: ByteArray, conversationId: String, duration: Int) = Unit