1500 lines
54 KiB
Python
1500 lines
54 KiB
Python
import asyncio
|
||
import unittest
|
||
from contextlib import ExitStack
|
||
from dataclasses import dataclass
|
||
from types import SimpleNamespace
|
||
from unittest.mock import AsyncMock, patch
|
||
|
||
from starlette.websockets import WebSocketDisconnect, WebSocketState
|
||
|
||
from app.features.conversation import service as conversation_feature_service
|
||
from app.features.conversation.models import Conversation, Segment
|
||
from app.features.conversation.ws import router as ws_router
|
||
|
||
|
||
class _FakeWebSocket:
    """In-memory stand-in for a Starlette WebSocket used by the endpoint tests.

    Messages are replayed from a scripted list: ordinary items are returned by
    ``receive_json``, exception instances are raised instead, and once the
    script is exhausted ``WebSocketDisconnect`` is raised. Each ``close`` call
    is recorded in ``close_calls`` and flips ``application_state`` to
    DISCONNECTED.
    """

    def __init__(self, messages, token="valid-token"):
        # token=None simulates a client that omitted the query parameter.
        self.query_params = {} if token is None else {"token": token}
        self._messages = [*messages]
        self.application_state = WebSocketState.CONNECTED
        self.close_calls = []

    async def receive_json(self):
        """Return the next scripted message, raising it if it is an exception."""
        if len(self._messages) == 0:
            raise WebSocketDisconnect()
        scripted = self._messages.pop(0)
        if not isinstance(scripted, BaseException):
            return scripted
        raise scripted

    async def close(self, code=None, reason=None):
        """Record the close request and mark the socket as disconnected."""
        self.close_calls.append(dict(code=code, reason=reason))
        self.application_state = WebSocketState.DISCONNECTED
|
||
|
||
|
||
@dataclass
|
||
class _ScalarsResult:
|
||
_items: list
|
||
|
||
def all(self):
|
||
return list(self._items)
|
||
|
||
|
||
@dataclass
|
||
class _ExecuteResult:
|
||
_items: list
|
||
|
||
def scalars(self):
|
||
return _ScalarsResult(self._items)
|
||
|
||
def scalar_one_or_none(self):
|
||
return self._items[0] if len(self._items) == 1 else None
|
||
|
||
|
||
class _FakeAsyncDB:
|
||
def __init__(self, user, conversation=None, segments=None):
|
||
self.user = user
|
||
self.conversation = conversation
|
||
self.segments = list(segments or [])
|
||
self.added = []
|
||
self.commit_calls = 0
|
||
self.refresh_calls = 0
|
||
self.state_result = None # 若为 None,get_or_create_state 的查询返回空
|
||
|
||
async def get(self, model, key):
|
||
model_name = getattr(model, "__name__", "")
|
||
if model_name == "User":
|
||
return self.user
|
||
if model_name == "Conversation":
|
||
if self.conversation and self.conversation.id == key:
|
||
return self.conversation
|
||
return None
|
||
return None
|
||
|
||
def add(self, obj):
|
||
self.added.append(obj)
|
||
if isinstance(obj, Conversation):
|
||
self.conversation = obj
|
||
if isinstance(obj, Segment):
|
||
self.segments.append(obj)
|
||
|
||
async def commit(self):
|
||
self.commit_calls += 1
|
||
|
||
async def refresh(self, obj):
|
||
_ = obj
|
||
self.refresh_calls += 1
|
||
|
||
async def execute(self, stmt):
|
||
stmt_str = str(stmt)
|
||
if "MemoirState" in stmt_str or "memoir_state" in stmt_str:
|
||
return _ExecuteResult(
|
||
[self.state_result] if self.state_result is not None else []
|
||
)
|
||
return _ExecuteResult(self.segments)
|
||
|
||
|
||
class _FakeManager:
|
||
def __init__(self):
|
||
self.active_connections = {}
|
||
self.segment_states = {}
|
||
self.sent_messages = []
|
||
self.disconnect_calls = []
|
||
self.background_runner = SimpleNamespace(
|
||
queue_message=AsyncMock(),
|
||
flush_pending=AsyncMock(),
|
||
)
|
||
self.conversation_agent = SimpleNamespace(
|
||
generate_profile_greeting=AsyncMock(return_value=[]),
|
||
generate_opening_message=AsyncMock(return_value=[]),
|
||
)
|
||
|
||
async def connect(self, websocket, conversation_id):
|
||
self.active_connections[conversation_id] = websocket
|
||
|
||
async def disconnect(self, conversation_id):
|
||
self.disconnect_calls.append(conversation_id)
|
||
self.active_connections.pop(conversation_id, None)
|
||
|
||
async def send_message(self, conversation_id, message):
|
||
self.sent_messages.append(
|
||
{"conversation_id": conversation_id, "message": message}
|
||
)
|
||
|
||
def get_or_create_segment_state(self, conversation_id, voice_session_id):
|
||
state_key = (conversation_id, voice_session_id)
|
||
if state_key not in self.segment_states:
|
||
self.segment_states[state_key] = ws_router.SegmentStreamState()
|
||
return self.segment_states[state_key]
|
||
|
||
def register_segment_task(self, conversation_id, voice_session_id, task):
|
||
state = self.get_or_create_segment_state(conversation_id, voice_session_id)
|
||
state.active_tasks.add(task)
|
||
|
||
def _cleanup(done_task):
|
||
state.active_tasks.discard(done_task)
|
||
|
||
task.add_done_callback(_cleanup)
|
||
|
||
|
||
def _make_user():
|
||
# Provide all profile fields to skip greeting/profile-collection branch.
|
||
return SimpleNamespace(
|
||
id="user-1",
|
||
nickname="tester",
|
||
subscription_type="premium",
|
||
birth_year=1990,
|
||
birth_place="A",
|
||
grew_up_place="B",
|
||
occupation="dev",
|
||
)
|
||
|
||
|
||
def _db_provider(db):
|
||
"""返回可被 patch 到 get_async_db 的异步生成器(旧用法)。"""
|
||
|
||
async def _provider():
|
||
yield db
|
||
|
||
return _provider
|
||
|
||
|
||
class _FakeSessionCM:
|
||
"""模拟 async with AsyncSessionLocal() as db 的上下文管理器。"""
|
||
|
||
def __init__(self, db):
|
||
self._db = db
|
||
|
||
async def __aenter__(self):
|
||
return self._db
|
||
|
||
async def __aexit__(self, *args):
|
||
pass
|
||
|
||
|
||
def _session_local_factory(fake_db):
    """Return a callable that can be patched over ``AsyncSessionLocal``.

    Each call produces a fresh ``_FakeSessionCM``, so
    ``async with AsyncSessionLocal() as db`` binds ``db`` to ``fake_db``.
    """
    return lambda: _FakeSessionCM(fake_db)
|
||
|
||
|
||
def _redis_empty_history_patch():
    """Build a patcher that makes redis report an empty conversation history.

    With no stored history the websocket sends its opening message, or skips
    it when that path is mocked.
    """
    empty_history = AsyncMock(return_value=[])
    return patch.object(
        conversation_feature_service.redis_service,
        "get_conversation_history",
        new=empty_history,
    )
|
||
|
||
|
||
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||
async def test_invalid_token_closes_connection(self):
    """An unverifiable token closes the socket once with the auth-failure reason."""
    websocket = _FakeWebSocket(messages=[], token="invalid")

    with patch.object(ws_router, "verify_token", return_value=None):
        await ws_router.websocket_endpoint(websocket, "conv-1")

    # Exactly one close call, carrying the "invalid auth token" reason.
    reasons = [call["reason"] for call in websocket.close_calls]
    self.assertEqual(reasons, ["无效的认证令牌"])
|
||
|
||
async def test_text_message_creates_segment_and_dispatches_agent(self):
    """A text message is stored as a Segment, queued, and forwarded to the agent."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {"type": "text", "data": {"text": "你好"}},
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()

    # Each target is patched exactly once; stacking the same redis patch
    # twice only added noise.
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
    self.assertEqual(len(segments), 1)
    self.assertEqual(segments[0].transcript_text, "你好")
    self.assertIsNone(segments[0].audio_url)
    self.assertIsNotNone(conversation.last_message_at)
    fake_manager.background_runner.queue_message.assert_awaited_once()
    process_user_message_mock.assert_awaited_once()

    message_types = [item["message"]["type"] for item in fake_manager.sent_messages]
    self.assertIn(ws_router.MessageType.CONNECT, message_types)
|
||
|
||
async def test_audio_message_transcribes_then_calls_agent(self):
    """An audio message is transcribed, persisted, echoed, and sent to the agent."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_message",
                "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 12},
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="这是转写结果")

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    # "ZmFrZS1hdWRpby1iNjQ=" is base64 for b"fake-audio-b64".
    transcribe_mock.assert_awaited_once_with(b"fake-audio-b64", "m4a")
    process_user_message_mock.assert_awaited_once()

    segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
    self.assertEqual(len(segments), 1)
    self.assertEqual(segments[0].transcript_text, "这是转写结果")
    self.assertEqual(segments[0].audio_url, "audio:12s")
    self.assertIsNotNone(conversation.last_message_at)

    transcript_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
    ]
    self.assertEqual(len(transcript_msgs), 1)
    self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")
|
||
|
||
async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
    """transcribe_only echoes the transcript without persisting or calling the agent."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "transcribe_only",
                "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ="},
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="仅转写文本")

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    transcribe_mock.assert_awaited_once_with(b"fake-audio-b64", "m4a")
    process_user_message_mock.assert_not_awaited()
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 0
    )

    transcript_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
    ]
    self.assertEqual(len(transcript_msgs), 1)
    self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")
|
||
|
||
async def test_end_conversation_updates_status_and_triggers_processing(self):
    """end_conversation marks the conversation ended and kicks off processing."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    # No explicit disconnect: the fake socket raises one once the script runs out.
    fake_websocket = _FakeWebSocket(
        messages=[{"type": "end_conversation", "conversation_id": "conv-1"}]
    )
    process_conversation_segments_mock = AsyncMock()

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch.object(
            ws_router,
            "process_conversation_segments",
            process_conversation_segments_mock,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    self.assertEqual(conversation.status, "ended")
    process_conversation_segments_mock.assert_awaited_once()
    end_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.END_CONVERSATION
    ]
    self.assertEqual(len(end_msgs), 1)
|
||
|
||
async def test_transcribe_only_missing_audio_returns_error(self):
    """transcribe_only without audio_base64 reports an error and skips ASR."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {"type": "transcribe_only", "data": {}},
            WebSocketDisconnect(),
        ]
    )
    transcribe_mock = AsyncMock(return_value="should-not-be-called")

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    transcribe_mock.assert_not_awaited()
    error_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.ERROR
    ]
    self.assertEqual(len(error_msgs), 1)
    self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")
|
||
|
||
async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
    """A failed transcription surfaces a user-facing error and skips the agent."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_message",
                "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 8},
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    # The ASR provider signals failure through a "转写失败:" ("transcription
    # failed:") prefixed string rather than by raising.
    transcribe_mock = AsyncMock(return_value="转写失败: mock error")

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")

    process_user_message_mock.assert_not_awaited()
    error_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.ERROR
    ]
    self.assertEqual(len(error_msgs), 1)
    self.assertEqual(
        error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入"
    )
|
||
|
||
async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self):
    """Segments arriving out of order reach the agent ordered by segment_index."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1",
                    "voice_session_id": "voice-session-1",
                    "segment_index": 1,
                    "duration": 12,
                    "is_last": False,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-0",
                    "voice_session_id": "voice-session-1",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": False,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    # Transcripts are returned in arrival order (segment 1 first).
    transcribe_mock = AsyncMock(side_effect=["这是第 1 段", "这是第 0 段"])

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Let background segment tasks finish while the patches are still active.
        await asyncio.sleep(0.05)

    self.assertEqual(transcribe_mock.await_count, 2)
    ordered_messages = [
        call.kwargs["user_message"]
        for call in process_user_message_mock.await_args_list
    ]
    self.assertEqual(ordered_messages, ["这是第 0 段", "这是第 1 段"])
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2
    )
    transcript_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
    ]
    self.assertEqual(
        [msg["data"]["voice_session_id"] for msg in transcript_msgs],
        ["voice-session-1", "voice-session-1"],
    )
|
||
|
||
async def test_audio_segment_duplicate_index_is_idempotent(self):
    """A repeated segment_index in the same session is processed only once."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "dup-seg-0-a",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": False,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "dup-seg-0-b",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="重复分段去重测试")

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Let background segment tasks finish while the patches are still active.
        await asyncio.sleep(0.05)

    self.assertEqual(transcribe_mock.await_count, 1)
    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 1
    )
|
||
|
||
async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(
    self,
):
    """The same segment_index in two different voice sessions is not deduplicated."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "session-a-seg-0",
                    "voice_session_id": "voice-session-a",
                    "client_segment_id": "voice-session-a-0",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": True,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "session-b-seg-0",
                    "voice_session_id": "voice-session-b",
                    "client_segment_id": "voice-session-b-0",
                    "segment_index": 0,
                    "duration": 8,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=["第一轮第 0 段", "第二轮第 0 段"])

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Let background segment tasks finish while the patches are still active.
        await asyncio.sleep(0.05)

    ordered_messages = [
        call.kwargs["user_message"]
        for call in process_user_message_mock.await_args_list
    ]
    self.assertEqual(ordered_messages, ["第一轮第 0 段", "第二轮第 0 段"])
    self.assertEqual(transcribe_mock.await_count, 2)
    self.assertEqual(
        len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2
    )
|
||
|
||
async def test_audio_segment_sends_transition_feedback_while_processing(self):
    """A slow single (is_last) segment emits no "listening" transition message."""
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "slow-seg-0",
                    "segment_index": 0,
                    "duration": 20,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )

    async def _slow_transcribe(*_args, **_kwargs) -> str:
        # Accept any call shape. Other tests show transcribe is awaited as
        # (audio_bytes, "m4a"), which the previous single-positional
        # signature rejected with TypeError before the delay could run.
        await asyncio.sleep(0.2)
        return "慢速转写"

    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=_slow_transcribe)

    def _asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each target is patched exactly once (the redis patch was stacked twice).
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", _asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            _asr_provider,
        ),
    )
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        await asyncio.sleep(0.05)

    # Per the original note: only the first segment schedules a single
    # "我在认真听" ("I'm listening") message after 5s, and it is cancelled
    # when that segment is_last, so none is expected here.
    transition_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
        and item["message"].get("data", {}).get("transition") is True
    ]
    self.assertEqual(len(transition_msgs), 0)
|
||
|
||
async def test_recording_started_sends_listening_feedback_after_delay(self):
    """Send one "我在认真听" transition a delay after ``recording_started``.

    After the client reports ``recording_started``, the server is expected to
    emit exactly one ``AGENT_RESPONSE`` flagged ``transition=True`` whose text
    contains "我在认真听" ("I'm listening carefully"). The production delay is
    5 s; here ``LISTENING_FEEDBACK_DELAY_SEC`` is patched to 0.05 s and the
    test waits 0.12 s so the delayed message has time to fire.
    """
    user = _make_user()
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=Conversation(id="conv-1", user_id=user.id, status="active"),
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {"type": "recording_started", "data": {"voice_session_id": "session-1"}},
            WebSocketDisconnect(),
        ]
    )

    # Shared dependencies are patched on both the router and the pipeline module.
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        _redis_empty_history_patch(),
        # Shrink the feedback delay so the test completes quickly.
        patch(
            "app.features.conversation.ws.pipeline.LISTENING_FEEDBACK_DELAY_SEC",
            0.05,
        ),
    )

    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Let the delayed feedback task fire while the patches are still active.
        await asyncio.sleep(0.12)

    transition_msgs = [
        sent["message"]
        for sent in fake_manager.sent_messages
        if sent["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
        and sent["message"].get("data", {}).get("transition") is True
    ]
    self.assertEqual(len(transition_msgs), 1)
    self.assertIn("我在认真听", transition_msgs[0]["data"].get("text", ""))
|
||
|
||
async def test_audio_segment_last_segment_does_not_emit_terminal_transition(self):
    """A lone ``audio_segment`` flagged ``is_last`` emits no transition message.

    The segment is both the first and the last of its voice session. The
    delayed "listening" feedback scheduled for a first segment is expected to
    be cancelled when that segment is also the last one. So no
    ``AGENT_RESPONSE`` with ``transition=True`` may be sent.
    """
    user = _make_user()
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=Conversation(id="conv-1", user_id=user.id, status="active"),
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "last-seg-0",
                    "voice_session_id": "voice-session-last",
                    "client_segment_id": "voice-session-last-0",
                    "segment_index": 0,
                    "duration": 15,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )

    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="最后一段转写")

    def fake_asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # The default listening-feedback delay is several seconds, far longer than
    # the settle wait below. Without shrinking it, the "no transition"
    # assertion would pass even if cancellation were broken. With a delay
    # shorter than the wait, an uncancelled feedback task would be observed.
    feedback_delay_sec = 0.01
    settle_sec = 0.05

    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", fake_asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            fake_asr_provider,
        ),
        patch(
            "app.features.conversation.ws.pipeline.LISTENING_FEEDBACK_DELAY_SEC",
            feedback_delay_sec,
        ),
    )

    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Wait longer than the shrunken delay so a leaked task would have fired.
        await asyncio.sleep(settle_sec)

    # Only one segment and it is the last: the delayed task is cancelled, so
    # no transition should have been emitted.
    transition_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
        and item["message"].get("data", {}).get("transition") is True
    ]
    self.assertEqual(len(transition_msgs), 0)
|
||
|
||
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(
    self,
):
    """A reconnected client can continue a voice session with the next segment.

    A segment for ``voice-session-1`` already exists, stored with audio_url
    ``audio-segment:voice-session-1:0``. On the new connection the client
    sends segment 1 with ``is_last=True``. The test expects
    ``process_user_message`` to be awaited exactly once, with that segment's
    transcript as ``user_message``.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    existing_segment = Segment(
        id="seg-existing-0",
        conversation_id="conv-1",
        transcript_text="已存在的上一段",
        audio_url="audio-segment:voice-session-1:0",
        processed=False,
    )
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=[existing_segment],
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1-after-reconnect",
                    "voice_session_id": "voice-session-1",
                    "client_segment_id": "voice-session-1-1",
                    "segment_index": 1,
                    "duration": 18,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )

    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="这是重连后的第 1 段")

    def fake_asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each shared dependency is patched on both the router and the pipeline
    # module. The Redis history patch is entered only once; it used to be
    # entered twice, which was redundant.
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", fake_asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            fake_asr_provider,
        ),
    )

    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Give background processing a moment to run while patches are active.
        await asyncio.sleep(0.05)

    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        process_user_message_mock.await_args.kwargs["user_message"],
        "这是重连后的第 1 段",
    )
|
||
|
||
async def test_audio_segment_reconnect_uses_contiguous_prefix_not_max_index(self):
    """Reconnect resume honours the contiguous stored prefix, not the max index.

    Segments 0 and 2 of ``voice-session-gap`` are already stored, so index 1
    is a gap. When the client uploads segment 1 (``is_last=False``), the test
    expects it to be transcribed exactly once. It also expects that transcript
    to be passed once to ``process_user_message`` as ``user_message``.
    """
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    existing_segments = [
        Segment(
            id="seg-existing-0",
            conversation_id="conv-1",
            transcript_text="已存在的第 0 段",
            audio_url="audio-segment:voice-session-gap:0",
            processed=False,
        ),
        Segment(
            id="seg-existing-2",
            conversation_id="conv-1",
            transcript_text="已存在的第 2 段",
            audio_url="audio-segment:voice-session-gap:2",
            processed=False,
        ),
    ]
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=existing_segments,
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1-gap-retry",
                    "voice_session_id": "voice-session-gap",
                    "client_segment_id": "voice-session-gap-1",
                    "segment_index": 1,
                    "duration": 18,
                    "is_last": False,
                },
            },
            WebSocketDisconnect(),
        ]
    )

    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="补传后的第 1 段")

    def fake_asr_provider():
        return SimpleNamespace(transcribe=transcribe_mock)

    # Each shared dependency is patched on both the router and the pipeline
    # module. The Redis history patch is entered only once; it used to be
    # entered twice, which was redundant.
    patchers = (
        patch.object(
            ws_router,
            "verify_token",
            return_value={"type": "access", "sub": user.id},
        ),
        patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)),
        patch(
            "app.features.conversation.ws.pipeline.AsyncSessionLocal",
            _session_local_factory(fake_db),
        ),
        patch.object(ws_router, "manager", fake_manager),
        _redis_empty_history_patch(),
        patch("app.features.conversation.ws.pipeline.manager", fake_manager),
        patch.object(ws_router, "background_runner", fake_manager.background_runner),
        patch(
            "app.features.conversation.ws.router.check_ws_quota",
            new=AsyncMock(return_value=(True, "")),
        ),
        patch.object(ws_router, "process_user_message", process_user_message_mock),
        patch(
            "app.features.conversation.ws.pipeline.process_user_message",
            process_user_message_mock,
        ),
        patch.object(ws_router, "get_asr_provider", fake_asr_provider),
        patch(
            "app.features.conversation.ws.pipeline.get_asr_provider",
            fake_asr_provider,
        ),
    )

    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Give background processing a moment to run while patches are active.
        await asyncio.sleep(0.05)

    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        process_user_message_mock.await_args.kwargs["user_message"],
        "补传后的第 1 段",
    )
    # Same check as await_count == 1, expressed with the mock's own assertion
    # for consistency with the line above.
    transcribe_mock.assert_awaited_once()
|
||
|
||
|
||
if __name__ == "__main__":
    # Allow running this test module directly; "__main__" is unittest.main's
    # default module, spelled out here for clarity.
    unittest.main(module="__main__")
|