441 lines
16 KiB
Python
441 lines
16 KiB
Python
|
|
import unittest
|
||
|
|
from contextlib import ExitStack
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from types import SimpleNamespace
|
||
|
|
from unittest.mock import AsyncMock, patch
|
||
|
|
|
||
|
|
from starlette.websockets import WebSocketDisconnect, WebSocketState
|
||
|
|
|
||
|
|
from database.models import Conversation, Segment
|
||
|
|
from routers import websocket as ws_router
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeWebSocket:
|
||
|
|
def __init__(self, messages, token="valid-token"):
|
||
|
|
self.query_params = {}
|
||
|
|
if token is not None:
|
||
|
|
self.query_params["token"] = token
|
||
|
|
self._messages = list(messages)
|
||
|
|
self.application_state = WebSocketState.CONNECTED
|
||
|
|
self.close_calls = []
|
||
|
|
|
||
|
|
async def receive_json(self):
|
||
|
|
if not self._messages:
|
||
|
|
raise WebSocketDisconnect()
|
||
|
|
next_item = self._messages.pop(0)
|
||
|
|
if isinstance(next_item, BaseException):
|
||
|
|
raise next_item
|
||
|
|
return next_item
|
||
|
|
|
||
|
|
async def close(self, code=None, reason=None):
|
||
|
|
self.close_calls.append({"code": code, "reason": reason})
|
||
|
|
self.application_state = WebSocketState.DISCONNECTED
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class _ScalarsResult:
|
||
|
|
_items: list
|
||
|
|
|
||
|
|
def all(self):
|
||
|
|
return list(self._items)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class _ExecuteResult:
    """Fake of the object returned by ``_FakeAsyncDB.execute``.

    Only ``scalars()`` is implemented; it wraps the same item list.
    """

    _items: list

    def scalars(self):
        """Return a ``_ScalarsResult`` over the stored items."""
        wrapped = _ScalarsResult(self._items)
        return wrapped
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeAsyncDB:
|
||
|
|
def __init__(self, user, conversation=None, segments=None):
|
||
|
|
self.user = user
|
||
|
|
self.conversation = conversation
|
||
|
|
self.segments = list(segments or [])
|
||
|
|
self.added = []
|
||
|
|
self.commit_calls = 0
|
||
|
|
self.refresh_calls = 0
|
||
|
|
|
||
|
|
async def get(self, model, key):
|
||
|
|
model_name = getattr(model, "__name__", "")
|
||
|
|
if model_name == "User":
|
||
|
|
return self.user
|
||
|
|
if model_name == "Conversation":
|
||
|
|
if self.conversation and self.conversation.id == key:
|
||
|
|
return self.conversation
|
||
|
|
return None
|
||
|
|
return None
|
||
|
|
|
||
|
|
def add(self, obj):
|
||
|
|
self.added.append(obj)
|
||
|
|
if isinstance(obj, Conversation):
|
||
|
|
self.conversation = obj
|
||
|
|
if isinstance(obj, Segment):
|
||
|
|
self.segments.append(obj)
|
||
|
|
|
||
|
|
async def commit(self):
|
||
|
|
self.commit_calls += 1
|
||
|
|
|
||
|
|
async def refresh(self, obj):
|
||
|
|
_ = obj
|
||
|
|
self.refresh_calls += 1
|
||
|
|
|
||
|
|
async def execute(self, stmt):
|
||
|
|
_ = stmt
|
||
|
|
return _ExecuteResult(self.segments)
|
||
|
|
|
||
|
|
|
||
|
|
class _FakeManager:
|
||
|
|
def __init__(self):
|
||
|
|
self.active_connections = {}
|
||
|
|
self.sent_messages = []
|
||
|
|
self.disconnect_calls = []
|
||
|
|
self.background_runner = SimpleNamespace(
|
||
|
|
queue_message=AsyncMock(),
|
||
|
|
flush_pending=AsyncMock(),
|
||
|
|
)
|
||
|
|
self.conversation_agent = SimpleNamespace(
|
||
|
|
generate_profile_greeting=AsyncMock(return_value=[]),
|
||
|
|
)
|
||
|
|
|
||
|
|
async def connect(self, websocket, conversation_id):
|
||
|
|
self.active_connections[conversation_id] = websocket
|
||
|
|
|
||
|
|
async def disconnect(self, conversation_id):
|
||
|
|
self.disconnect_calls.append(conversation_id)
|
||
|
|
self.active_connections.pop(conversation_id, None)
|
||
|
|
|
||
|
|
async def send_message(self, conversation_id, message):
|
||
|
|
self.sent_messages.append({"conversation_id": conversation_id, "message": message})
|
||
|
|
|
||
|
|
|
||
|
|
def _make_user():
|
||
|
|
# Provide all profile fields to skip greeting/profile-collection branch.
|
||
|
|
return SimpleNamespace(
|
||
|
|
id="user-1",
|
||
|
|
nickname="tester",
|
||
|
|
subscription_type="premium",
|
||
|
|
birth_year=1990,
|
||
|
|
birth_place="A",
|
||
|
|
grew_up_place="B",
|
||
|
|
occupation="dev",
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _db_provider(db):
|
||
|
|
async def _provider():
|
||
|
|
yield db
|
||
|
|
|
||
|
|
return _provider
|
||
|
|
|
||
|
|
|
||
|
|
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
    """Baseline behaviour tests for ``ws_router.websocket_endpoint``.

    Each test scripts the inbound messages of a fake socket, patches the
    router's collaborators (token check, DB dependency, connection manager,
    and optionally quota helpers, ASR and agent entry points), runs the
    endpoint, and inspects what the fakes recorded.
    """

    _CONV_ID = "conv-1"
    # base64 of b"fake-audio-b64"; the content is irrelevant because ASR is mocked.
    _AUDIO_B64 = "ZmFrZS1hdWRpby1iNjQ="

    def _build_fixture(self, messages):
        """Create a user, an active conversation, and the fakes around them.

        Args:
            messages: Scripted inbound payloads for the fake socket.

        Returns:
            ``(user, conversation, fake_db, fake_manager, fake_websocket)``.
        """
        user = _make_user()
        conversation = Conversation(id=self._CONV_ID, user_id=user.id, status="active")
        fake_db = _FakeAsyncDB(user=user, conversation=conversation)
        fake_manager = _FakeManager()
        fake_websocket = _FakeWebSocket(messages=messages)
        return user, conversation, fake_db, fake_manager, fake_websocket

    async def _run_endpoint(
        self,
        user,
        fake_db,
        fake_manager,
        fake_websocket,
        *,
        quota_ok=False,
        router_patches=None,
        transcribe=None,
    ):
        """Run the endpoint for ``_CONV_ID`` with the standard patches applied.

        Args:
            user: Authenticated user; its ``id`` becomes the token subject.
            fake_db: Yielded by the patched ``get_async_db`` dependency.
            fake_manager: Replaces the router's ``manager``.
            fake_websocket: Socket passed to the endpoint.
            quota_ok: When true, patch the quota helpers so that sending is
                allowed with zero segments used.
            router_patches: Extra ``{attribute: replacement}`` patches applied
                to the router module (e.g. ``process_user_message``).
            transcribe: Replacement for ``ws_router.asr_service.transcribe``;
                ``None`` leaves the real ASR service in place.
        """
        with ExitStack() as stack:
            stack.enter_context(
                patch.object(
                    ws_router,
                    "verify_token",
                    return_value={"type": "access", "sub": user.id},
                )
            )
            stack.enter_context(
                patch.object(ws_router, "get_async_db", _db_provider(fake_db))
            )
            stack.enter_context(patch.object(ws_router, "manager", fake_manager))
            if quota_ok:
                stack.enter_context(
                    patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
                )
                stack.enter_context(
                    patch("routers.quota.check_can_send_message", return_value=(True, ""))
                )
            for attribute, replacement in (router_patches or {}).items():
                stack.enter_context(patch.object(ws_router, attribute, replacement))
            if transcribe is not None:
                stack.enter_context(
                    patch.object(ws_router.asr_service, "transcribe", transcribe)
                )

            await ws_router.websocket_endpoint(fake_websocket, self._CONV_ID)

    @staticmethod
    def _sent_of_type(fake_manager, message_type):
        """Return the payloads sent through ``fake_manager`` with the given type."""
        return [
            item["message"]
            for item in fake_manager.sent_messages
            if item["message"]["type"] == message_type
        ]

    @staticmethod
    def _added_segments(fake_db):
        """Return the ``Segment`` objects passed to ``fake_db.add``."""
        return [obj for obj in fake_db.added if isinstance(obj, Segment)]

    async def test_invalid_token_closes_connection(self):
        websocket = _FakeWebSocket(messages=[], token="invalid")

        with patch.object(ws_router, "verify_token", return_value=None):
            await ws_router.websocket_endpoint(websocket, self._CONV_ID)

        self.assertEqual(len(websocket.close_calls), 1)
        # Expected reason: "invalid authentication token".
        self.assertEqual(websocket.close_calls[0]["reason"], "无效的认证令牌")

    async def test_text_message_creates_segment_and_dispatches_agent(self):
        user, _, fake_db, fake_manager, fake_websocket = self._build_fixture(
            [
                {"type": "text", "data": {"text": "你好"}},  # "hello"
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()

        await self._run_endpoint(
            user,
            fake_db,
            fake_manager,
            fake_websocket,
            quota_ok=True,
            router_patches={"process_user_message": process_user_message_mock},
        )

        segments = self._added_segments(fake_db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "你好")
        self.assertIsNone(segments[0].audio_url)
        fake_manager.background_runner.queue_message.assert_awaited_once()
        process_user_message_mock.assert_awaited_once()

        message_types = [item["message"]["type"] for item in fake_manager.sent_messages]
        self.assertIn(ws_router.MessageType.CONNECT, message_types)

    async def test_audio_message_transcribes_then_calls_agent(self):
        user, _, fake_db, fake_manager, fake_websocket = self._build_fixture(
            [
                {
                    "type": "audio_message",
                    "data": {"audio_base64": self._AUDIO_B64, "duration": 12},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        # "This is the transcription result".
        transcribe_mock = AsyncMock(return_value="这是转写结果")

        await self._run_endpoint(
            user,
            fake_db,
            fake_manager,
            fake_websocket,
            quota_ok=True,
            router_patches={"process_user_message": process_user_message_mock},
            transcribe=transcribe_mock,
        )

        transcribe_mock.assert_awaited_once_with(self._AUDIO_B64)
        process_user_message_mock.assert_awaited_once()

        segments = self._added_segments(fake_db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "这是转写结果")
        self.assertEqual(segments[0].audio_url, "audio:12s")

        transcript_msgs = self._sent_of_type(fake_manager, ws_router.MessageType.TRANSCRIPT)
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")

    async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
        user, _, fake_db, fake_manager, fake_websocket = self._build_fixture(
            [
                {"type": "transcribe_only", "data": {"audio_base64": self._AUDIO_B64}},
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="仅转写文本")  # "transcription-only text"

        await self._run_endpoint(
            user,
            fake_db,
            fake_manager,
            fake_websocket,
            router_patches={"process_user_message": process_user_message_mock},
            transcribe=transcribe_mock,
        )

        transcribe_mock.assert_awaited_once_with(self._AUDIO_B64)
        process_user_message_mock.assert_not_awaited()
        self.assertEqual(len(self._added_segments(fake_db)), 0)

        transcript_msgs = self._sent_of_type(fake_manager, ws_router.MessageType.TRANSCRIPT)
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")

    async def test_end_conversation_updates_status_and_triggers_processing(self):
        # No trailing WebSocketDisconnect: the fake socket raises one itself
        # once its script is exhausted.
        user, conversation, fake_db, fake_manager, fake_websocket = self._build_fixture(
            [{"type": "end_conversation", "conversation_id": self._CONV_ID}]
        )
        process_conversation_segments_mock = AsyncMock()

        await self._run_endpoint(
            user,
            fake_db,
            fake_manager,
            fake_websocket,
            router_patches={
                "process_conversation_segments": process_conversation_segments_mock,
            },
        )

        self.assertEqual(conversation.status, "ended")
        process_conversation_segments_mock.assert_awaited_once()
        end_msgs = self._sent_of_type(fake_manager, ws_router.MessageType.END_CONVERSATION)
        self.assertEqual(len(end_msgs), 1)

    async def test_transcribe_only_missing_audio_returns_error(self):
        user, _, fake_db, fake_manager, fake_websocket = self._build_fixture(
            [
                {"type": "transcribe_only", "data": {}},
                WebSocketDisconnect(),
            ]
        )
        transcribe_mock = AsyncMock(return_value="should-not-be-called")
        # Stub the agent too, so a regression cannot reach the real pipeline.
        process_user_message_mock = AsyncMock()

        await self._run_endpoint(
            user,
            fake_db,
            fake_manager,
            fake_websocket,
            router_patches={"process_user_message": process_user_message_mock},
            transcribe=transcribe_mock,
        )

        transcribe_mock.assert_not_awaited()
        process_user_message_mock.assert_not_awaited()
        error_msgs = self._sent_of_type(fake_manager, ws_router.MessageType.ERROR)
        self.assertEqual(len(error_msgs), 1)
        # Expected message: "missing audio_base64".
        self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")

    async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
        user, _, fake_db, fake_manager, fake_websocket = self._build_fixture(
            [
                {
                    "type": "audio_message",
                    "data": {"audio_base64": self._AUDIO_B64, "duration": 8},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        # ASR failure is signalled in-band: "transcription failed: mock error".
        transcribe_mock = AsyncMock(return_value="转写失败: mock error")

        await self._run_endpoint(
            user,
            fake_db,
            fake_manager,
            fake_websocket,
            quota_ok=True,
            router_patches={"process_user_message": process_user_message_mock},
            transcribe=transcribe_mock,
        )

        process_user_message_mock.assert_not_awaited()
        error_msgs = self._sent_of_type(fake_manager, ws_router.MessageType.ERROR)
        self.assertEqual(len(error_msgs), 1)
        # "Voice transcription failed, please retry or use text input".
        self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Allow running this test module directly with the Python interpreter.
    unittest.main()
|