test: 在重构前补充 WebSocket 基线测试

补齐 WebSocket 关键流程的基线测试,为后续长语音重构提供回归保护。
This commit is contained in:
iammm0
2026-03-09 09:05:21 +08:00
parent febeaed0ae
commit 9cdd2bdf2f

View File

@@ -0,0 +1,440 @@
import unittest
from contextlib import ExitStack
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from starlette.websockets import WebSocketDisconnect, WebSocketState
from database.models import Conversation, Segment
from routers import websocket as ws_router
class _FakeWebSocket:
def __init__(self, messages, token="valid-token"):
self.query_params = {}
if token is not None:
self.query_params["token"] = token
self._messages = list(messages)
self.application_state = WebSocketState.CONNECTED
self.close_calls = []
async def receive_json(self):
if not self._messages:
raise WebSocketDisconnect()
next_item = self._messages.pop(0)
if isinstance(next_item, BaseException):
raise next_item
return next_item
async def close(self, code=None, reason=None):
self.close_calls.append({"code": code, "reason": reason})
self.application_state = WebSocketState.DISCONNECTED
@dataclass
class _ScalarsResult:
_items: list
def all(self):
return list(self._items)
@dataclass
class _ExecuteResult:
_items: list
def scalars(self):
return _ScalarsResult(self._items)
class _FakeAsyncDB:
def __init__(self, user, conversation=None, segments=None):
self.user = user
self.conversation = conversation
self.segments = list(segments or [])
self.added = []
self.commit_calls = 0
self.refresh_calls = 0
async def get(self, model, key):
model_name = getattr(model, "__name__", "")
if model_name == "User":
return self.user
if model_name == "Conversation":
if self.conversation and self.conversation.id == key:
return self.conversation
return None
return None
def add(self, obj):
self.added.append(obj)
if isinstance(obj, Conversation):
self.conversation = obj
if isinstance(obj, Segment):
self.segments.append(obj)
async def commit(self):
self.commit_calls += 1
async def refresh(self, obj):
_ = obj
self.refresh_calls += 1
async def execute(self, stmt):
_ = stmt
return _ExecuteResult(self.segments)
class _FakeManager:
def __init__(self):
self.active_connections = {}
self.sent_messages = []
self.disconnect_calls = []
self.background_runner = SimpleNamespace(
queue_message=AsyncMock(),
flush_pending=AsyncMock(),
)
self.conversation_agent = SimpleNamespace(
generate_profile_greeting=AsyncMock(return_value=[]),
)
async def connect(self, websocket, conversation_id):
self.active_connections[conversation_id] = websocket
async def disconnect(self, conversation_id):
self.disconnect_calls.append(conversation_id)
self.active_connections.pop(conversation_id, None)
async def send_message(self, conversation_id, message):
self.sent_messages.append({"conversation_id": conversation_id, "message": message})
def _make_user():
# Provide all profile fields to skip greeting/profile-collection branch.
return SimpleNamespace(
id="user-1",
nickname="tester",
subscription_type="premium",
birth_year=1990,
birth_place="A",
grew_up_place="B",
occupation="dev",
)
def _db_provider(db):
async def _provider():
yield db
return _provider
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
async def test_invalid_token_closes_connection(self):
websocket = _FakeWebSocket(messages=[], token="invalid")
with patch.object(ws_router, "verify_token", return_value=None):
await ws_router.websocket_endpoint(websocket, "conv-1")
self.assertEqual(len(websocket.close_calls), 1)
self.assertEqual(websocket.close_calls[0]["reason"], "无效的认证令牌")
async def test_text_message_creates_segment_and_dispatches_agent(self):
user = _make_user()
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
fake_manager = _FakeManager()
fake_websocket = _FakeWebSocket(
messages=[
{"type": "text", "data": {"text": "你好"}},
WebSocketDisconnect(),
]
)
process_user_message_mock = AsyncMock()
with ExitStack() as stack:
stack.enter_context(
patch.object(
ws_router,
"verify_token",
return_value={"type": "access", "sub": user.id},
)
)
stack.enter_context(
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
stack.enter_context(
patch("routers.quota.check_can_send_message", return_value=(True, ""))
)
stack.enter_context(
patch.object(ws_router, "process_user_message", process_user_message_mock)
)
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
self.assertEqual(len(segments), 1)
self.assertEqual(segments[0].transcript_text, "你好")
self.assertIsNone(segments[0].audio_url)
fake_manager.background_runner.queue_message.assert_awaited_once()
process_user_message_mock.assert_awaited_once()
message_types = [item["message"]["type"] for item in fake_manager.sent_messages]
self.assertIn(ws_router.MessageType.CONNECT, message_types)
async def test_audio_message_transcribes_then_calls_agent(self):
user = _make_user()
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
fake_manager = _FakeManager()
fake_websocket = _FakeWebSocket(
messages=[
{
"type": "audio_message",
"data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 12},
},
WebSocketDisconnect(),
]
)
process_user_message_mock = AsyncMock()
transcribe_mock = AsyncMock(return_value="这是转写结果")
with ExitStack() as stack:
stack.enter_context(
patch.object(
ws_router,
"verify_token",
return_value={"type": "access", "sub": user.id},
)
)
stack.enter_context(
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
stack.enter_context(
patch("routers.quota.check_can_send_message", return_value=(True, ""))
)
stack.enter_context(
patch.object(ws_router, "process_user_message", process_user_message_mock)
)
stack.enter_context(
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
)
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
process_user_message_mock.assert_awaited_once()
segments = [obj for obj in fake_db.added if isinstance(obj, Segment)]
self.assertEqual(len(segments), 1)
self.assertEqual(segments[0].transcript_text, "这是转写结果")
self.assertEqual(segments[0].audio_url, "audio:12s")
transcript_msgs = [
item["message"]
for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
]
self.assertEqual(len(transcript_msgs), 1)
self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")
async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
user = _make_user()
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
fake_manager = _FakeManager()
fake_websocket = _FakeWebSocket(
messages=[
{
"type": "transcribe_only",
"data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ="},
},
WebSocketDisconnect(),
]
)
process_user_message_mock = AsyncMock()
transcribe_mock = AsyncMock(return_value="仅转写文本")
with ExitStack() as stack:
stack.enter_context(
patch.object(
ws_router,
"verify_token",
return_value={"type": "access", "sub": user.id},
)
)
stack.enter_context(
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(
patch.object(ws_router, "process_user_message", process_user_message_mock)
)
stack.enter_context(
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
)
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
process_user_message_mock.assert_not_awaited()
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 0)
transcript_msgs = [
item["message"]
for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.TRANSCRIPT
]
self.assertEqual(len(transcript_msgs), 1)
self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")
async def test_end_conversation_updates_status_and_triggers_processing(self):
user = _make_user()
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
fake_manager = _FakeManager()
fake_websocket = _FakeWebSocket(
messages=[{"type": "end_conversation", "conversation_id": "conv-1"}]
)
process_conversation_segments_mock = AsyncMock()
with ExitStack() as stack:
stack.enter_context(
patch.object(
ws_router,
"verify_token",
return_value={"type": "access", "sub": user.id},
)
)
stack.enter_context(
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(
patch.object(
ws_router,
"process_conversation_segments",
process_conversation_segments_mock,
)
)
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
self.assertEqual(conversation.status, "ended")
process_conversation_segments_mock.assert_awaited_once()
end_msgs = [
item["message"]
for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.END_CONVERSATION
]
self.assertEqual(len(end_msgs), 1)
async def test_transcribe_only_missing_audio_returns_error(self):
user = _make_user()
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
fake_manager = _FakeManager()
fake_websocket = _FakeWebSocket(
messages=[
{"type": "transcribe_only", "data": {}},
WebSocketDisconnect(),
]
)
transcribe_mock = AsyncMock(return_value="should-not-be-called")
with ExitStack() as stack:
stack.enter_context(
patch.object(
ws_router,
"verify_token",
return_value={"type": "access", "sub": user.id},
)
)
stack.enter_context(
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
)
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
transcribe_mock.assert_not_awaited()
error_msgs = [
item["message"]
for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.ERROR
]
self.assertEqual(len(error_msgs), 1)
self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")
async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
user = _make_user()
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
fake_manager = _FakeManager()
fake_websocket = _FakeWebSocket(
messages=[
{
"type": "audio_message",
"data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 8},
},
WebSocketDisconnect(),
]
)
process_user_message_mock = AsyncMock()
transcribe_mock = AsyncMock(return_value="转写失败: mock error")
with ExitStack() as stack:
stack.enter_context(
patch.object(
ws_router,
"verify_token",
return_value={"type": "access", "sub": user.id},
)
)
stack.enter_context(
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
)
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
stack.enter_context(
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
)
stack.enter_context(
patch("routers.quota.check_can_send_message", return_value=(True, ""))
)
stack.enter_context(
patch.object(ws_router, "process_user_message", process_user_message_mock)
)
stack.enter_context(
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
)
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
process_user_message_mock.assert_not_awaited()
error_msgs = [
item["message"]
for item in fake_manager.sent_messages
if item["message"]["type"] == ws_router.MessageType.ERROR
]
self.assertEqual(len(error_msgs), 1)
self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")
if __name__ == "__main__":
unittest.main()