feat: 支持长语音分段上传与断线补传

This commit is contained in:
Kevin
2026-03-09 15:30:18 +08:00
parent 440f5be07f
commit 6ffe96d7a9
13 changed files with 1451 additions and 19 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import unittest
from contextlib import ExitStack
from dataclasses import dataclass
@@ -89,6 +90,7 @@ class _FakeAsyncDB:
class _FakeManager:
def __init__(self):
self.active_connections = {}
self.segment_states = {}
self.sent_messages = []
self.disconnect_calls = []
self.background_runner = SimpleNamespace(
@@ -109,6 +111,21 @@ class _FakeManager:
async def send_message(self, conversation_id, message):
    """Record an outgoing message in ``self.sent_messages`` for later assertions."""
    record = dict(conversation_id=conversation_id, message=message)
    self.sent_messages.append(record)
def get_or_create_segment_state(self, conversation_id, voice_session_id):
    """Return the stream state stored for this (conversation, voice session) pair.

    A new ``ws_router.SegmentStreamState`` is created and cached the first
    time a given pair is requested.
    """
    key = (conversation_id, voice_session_id)
    try:
        return self.segment_states[key]
    except KeyError:
        created = ws_router.SegmentStreamState()
        self.segment_states[key] = created
        return created
def register_segment_task(self, conversation_id, voice_session_id, task):
    """Add *task* to the session state's ``active_tasks`` and remove it once done."""
    stream_state = self.get_or_create_segment_state(conversation_id, voice_session_id)
    stream_state.active_tasks.add(task)
    # The done-callback receives the finished task and drops it from the set.
    task.add_done_callback(
        lambda finished: stream_state.active_tasks.discard(finished)
    )
def _make_user():
# Provide all profile fields to skip greeting/profile-collection branch.
@@ -435,6 +452,431 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(error_msgs), 1)
self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")
async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self):
    """Deliver segment 1 before segment 0 and expect processing in index order."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    db = _FakeAsyncDB(user=user, conversation=conversation)
    manager = _FakeManager()

    # Index 1 arrives first, index 0 second; neither is flagged as the last one.
    incoming = [
        {
            "type": "audio_segment",
            "data": {
                "audio_base64": audio,
                "segment_index": index,
                "duration": duration,
                "is_last": False,
            },
        }
        for audio, index, duration in (("seg-1", 1, 12), ("seg-0", 0, 10))
    ]
    incoming.append(WebSocketDisconnect())
    websocket = _FakeWebSocket(messages=incoming)

    transcripts = {
        "seg-0": "这是第 0 段",
        "seg-1": "这是第 1 段",
    }

    def _lookup_transcript(audio):
        return transcripts[audio]

    process_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=_lookup_transcript)

    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "get_async_db", _db_provider(db)),
        patch.object(ws_router, "manager", manager),
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)),
        patch("routers.quota.check_can_send_message", return_value=(True, "")),
        patch.object(ws_router, "process_user_message", process_mock),
        patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(websocket, "conv-1")
        # Presumably lets background segment tasks finish -- TODO confirm.
        await asyncio.sleep(0.05)

    self.assertEqual(transcribe_mock.await_count, 2)
    delivered = []
    for awaited in process_mock.await_args_list:
        delivered.append(awaited.kwargs["user_message"])
    self.assertEqual(delivered, ["这是第 0 段", "这是第 1 段"])
    stored_segments = [obj for obj in db.added if isinstance(obj, Segment)]
    self.assertEqual(len(stored_segments), 2)
async def test_audio_segment_duplicate_index_is_idempotent(self):
    """Send segment index 0 twice and expect it to be handled only once."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    db = _FakeAsyncDB(user=user, conversation=conversation)
    manager = _FakeManager()

    # Both payloads carry segment_index 0; only the audio and is_last differ.
    incoming = [
        {
            "type": "audio_segment",
            "data": {
                "audio_base64": audio,
                "segment_index": 0,
                "duration": 10,
                "is_last": is_last,
            },
        }
        for audio, is_last in (("dup-seg-0-a", False), ("dup-seg-0-b", True))
    ]
    incoming.append(WebSocketDisconnect())
    websocket = _FakeWebSocket(messages=incoming)

    process_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="重复分段去重测试")

    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "get_async_db", _db_provider(db)),
        patch.object(ws_router, "manager", manager),
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)),
        patch("routers.quota.check_can_send_message", return_value=(True, "")),
        patch.object(ws_router, "process_user_message", process_mock),
        patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(websocket, "conv-1")
        # Presumably lets background segment tasks finish -- TODO confirm.
        await asyncio.sleep(0.05)

    self.assertEqual(transcribe_mock.await_count, 1)
    process_mock.assert_awaited_once()
    stored_segments = [obj for obj in db.added if isinstance(obj, Segment)]
    self.assertEqual(len(stored_segments), 1)
async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(self):
    """Index 0 from two different voice sessions must both be processed."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    db = _FakeAsyncDB(user=user, conversation=conversation)
    manager = _FakeManager()

    # (audio, voice_session_id, client_segment_id, duration) per message.
    segment_specs = (
        ("session-a-seg-0", "voice-session-a", "voice-session-a-0", 10),
        ("session-b-seg-0", "voice-session-b", "voice-session-b-0", 8),
    )
    incoming = [
        {
            "type": "audio_segment",
            "data": {
                "audio_base64": audio,
                "voice_session_id": session_id,
                "client_segment_id": client_id,
                "segment_index": 0,
                "duration": duration,
                "is_last": True,
            },
        }
        for audio, session_id, client_id, duration in segment_specs
    ]
    incoming.append(WebSocketDisconnect())
    websocket = _FakeWebSocket(messages=incoming)

    transcripts = {
        "session-a-seg-0": "第一轮第 0 段",
        "session-b-seg-0": "第二轮第 0 段",
    }

    def _lookup_transcript(audio):
        return transcripts[audio]

    process_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=_lookup_transcript)

    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "get_async_db", _db_provider(db)),
        patch.object(ws_router, "manager", manager),
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)),
        patch("routers.quota.check_can_send_message", return_value=(True, "")),
        patch.object(ws_router, "process_user_message", process_mock),
        patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(websocket, "conv-1")
        # Presumably lets background segment tasks finish -- TODO confirm.
        await asyncio.sleep(0.05)

    delivered = []
    for awaited in process_mock.await_args_list:
        delivered.append(awaited.kwargs["user_message"])
    self.assertEqual(delivered, ["第一轮第 0 段", "第二轮第 0 段"])
    self.assertEqual(transcribe_mock.await_count, 2)
    stored_segments = [obj for obj in db.added if isinstance(obj, Segment)]
    self.assertEqual(len(stored_segments), 2)
async def test_audio_segment_sends_transition_feedback_while_processing(self):
    """Expect at least one AGENT_RESPONSE flagged ``transition`` with slow ASR."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    db = _FakeAsyncDB(user=user, conversation=conversation)
    manager = _FakeManager()

    only_segment = {
        "type": "audio_segment",
        "data": {
            "audio_base64": "slow-seg-0",
            "segment_index": 0,
            "duration": 20,
            "is_last": True,
        },
    }
    websocket = _FakeWebSocket(messages=[only_segment, WebSocketDisconnect()])

    async def _slow_transcribe(_: str) -> str:
        # Transcription takes 0.2 s, longer than the 0.05 s pause further down.
        await asyncio.sleep(0.2)
        return "慢速转写"

    process_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=_slow_transcribe)

    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "get_async_db", _db_provider(db)),
        patch.object(ws_router, "manager", manager),
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)),
        patch("routers.quota.check_can_send_message", return_value=(True, "")),
        patch.object(ws_router, "process_user_message", process_mock),
        patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(websocket, "conv-1")
        await asyncio.sleep(0.05)

    # Keep only agent responses whose data explicitly sets transition=True.
    transition_msgs = []
    for sent in manager.sent_messages:
        payload = sent["message"]
        if (
            payload["type"] == ws_router.MessageType.AGENT_RESPONSE
            and payload.get("data", {}).get("transition") is True
        ):
            transition_msgs.append(payload)
    self.assertGreaterEqual(len(transition_msgs), 1)
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self):
    """Segment 1 is processed when the fake DB already holds a segment for index 0."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    # audio_url presumably encodes "<voice_session_id>:<segment_index>" -- TODO confirm.
    stored_before = Segment(
        id="seg-existing-0",
        conversation_id="conv-1",
        transcript_text="已存在的上一段",
        audio_url="audio-segment:voice-session-1:0",
        processed=False,
    )
    db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=[stored_before],
    )
    manager = _FakeManager()

    resumed_segment = {
        "type": "audio_segment",
        "data": {
            "audio_base64": "seg-1-after-reconnect",
            "voice_session_id": "voice-session-1",
            "client_segment_id": "voice-session-1-1",
            "segment_index": 1,
            "duration": 18,
            "is_last": True,
        },
    }
    websocket = _FakeWebSocket(messages=[resumed_segment, WebSocketDisconnect()])

    process_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="这是重连后的第 1 段")

    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "get_async_db", _db_provider(db)),
        patch.object(ws_router, "manager", manager),
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)),
        patch("routers.quota.check_can_send_message", return_value=(True, "")),
        patch.object(ws_router, "process_user_message", process_mock),
        patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(websocket, "conv-1")
        # Presumably lets background segment tasks finish -- TODO confirm.
        await asyncio.sleep(0.05)

    process_mock.assert_awaited_once()
    forwarded_text = process_mock.await_args.kwargs["user_message"]
    self.assertEqual(forwarded_text, "这是重连后的第 1 段")
async def test_audio_segment_reconnect_uses_contiguous_prefix_not_max_index(self):
    """With indexes 0 and 2 already stored, a re-sent index 1 is still processed."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    # Stored segments leave a gap at index 1 (only 0 and 2 are present).
    segment_at_0 = Segment(
        id="seg-existing-0",
        conversation_id="conv-1",
        transcript_text="已存在的第 0 段",
        audio_url="audio-segment:voice-session-gap:0",
        processed=False,
    )
    segment_at_2 = Segment(
        id="seg-existing-2",
        conversation_id="conv-1",
        transcript_text="已存在的第 2 段",
        audio_url="audio-segment:voice-session-gap:2",
        processed=False,
    )
    db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=[segment_at_0, segment_at_2],
    )
    manager = _FakeManager()

    gap_filler = {
        "type": "audio_segment",
        "data": {
            "audio_base64": "seg-1-gap-retry",
            "voice_session_id": "voice-session-gap",
            "client_segment_id": "voice-session-gap-1",
            "segment_index": 1,
            "duration": 18,
            "is_last": False,
        },
    }
    websocket = _FakeWebSocket(messages=[gap_filler, WebSocketDisconnect()])

    process_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="补传后的第 1 段")

    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "get_async_db", _db_provider(db)),
        patch.object(ws_router, "manager", manager),
        patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0)),
        patch("routers.quota.check_can_send_message", return_value=(True, "")),
        patch.object(ws_router, "process_user_message", process_mock),
        patch.object(ws_router.asr_service, "transcribe", transcribe_mock),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(websocket, "conv-1")
        # Presumably lets background segment tasks finish -- TODO confirm.
        await asyncio.sleep(0.05)

    process_mock.assert_awaited_once()
    forwarded_text = process_mock.await_args.kwargs["user_message"]
    self.assertEqual(forwarded_text, "补传后的第 1 段")
    self.assertEqual(transcribe_mock.await_count, 1)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()