* fix/ 0:00 audio ui

* fix/ persist memoir image state and collapse voice history

  Keep generated chapter images from staying in processing after successful uploads, and restore segmented voice recordings as a single audio message when reopening conversations.

  Made-with: Cursor

* fix/ persist local conversation state and stabilize voice UI

  Keep CreateMemory conversations driven by Room so recent text and audio survive page exits, and prevent stale 0:00 voice bubbles while list ordering follows the latest local message time.

  Made-with: Cursor

* fix/ server-side root cause for conversation list time and message timestamps

  - Add Conversation.last_message_at column with migration and index
  - Update last_message_at on text message, audio segment, and AI response
  - Sort conversation list by COALESCE(last_message_at, started_at) DESC
  - Return real per-message timestamps from Redis history instead of now()
  - Pass user_message_timestamp through agent pipeline to avoid LLM delay skew
  - Remove all debug logging from server, client, and CI workflow
  - Restore import json in conversation_agent (was broken by debug removal)
  - Client: remove DebugRuntimeLogger, stop sending transcript as text message

  Made-with: Cursor

---------

Co-authored-by: Kevin <kevin@brighteng.org>
981 lines
37 KiB
Python
981 lines
37 KiB
Python
import asyncio
|
|
import unittest
|
|
from contextlib import ExitStack
|
|
from dataclasses import dataclass
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
from starlette.websockets import WebSocketDisconnect, WebSocketState
|
|
|
|
from database.models import Conversation, Segment
|
|
from routers import websocket as ws_router
|
|
|
|
|
|
class _FakeWebSocket:
    """Scripted in-memory stand-in for a websocket connection.

    ``receive_json`` replays the frames given at construction time, and every
    ``close`` call is recorded in ``close_calls`` for later assertions.
    """

    def __init__(self, messages, token="valid-token"):
        # A ``None`` token leaves ``query_params`` empty; otherwise it is
        # stored under the "token" key.
        self.query_params = {} if token is None else {"token": token}
        # Copy so that consuming frames never mutates the caller's sequence.
        self._messages = [*messages]
        self.application_state = WebSocketState.CONNECTED
        self.close_calls = []

    async def receive_json(self):
        """Return the next scripted frame.

        A scripted item that is an exception instance is raised instead of
        returned. Once the script is exhausted, ``WebSocketDisconnect`` is
        raised.
        """
        if self._messages:
            frame = self._messages.pop(0)
            if isinstance(frame, BaseException):
                raise frame
            return frame
        raise WebSocketDisconnect()

    async def close(self, code=None, reason=None):
        """Record the close arguments and mark the socket as disconnected."""
        self.close_calls.append(dict(code=code, reason=reason))
        self.application_state = WebSocketState.DISCONNECTED
|
|
|
|
|
|
@dataclass
class _ScalarsResult:
    """Value handed back by ``_ExecuteResult.scalars()``; exposes ``all()``."""

    # Rows that ``all`` will report.
    _items: list

    def all(self):
        """Return a new list holding the stored rows."""
        return [*self._items]
|
|
|
|
|
|
@dataclass
class _ExecuteResult:
    """Value produced by ``_FakeAsyncDB.execute``; exposes ``scalars()``."""

    # Rows forwarded to the ``_ScalarsResult`` wrapper.
    _items: list

    def scalars(self):
        """Wrap the stored rows in a ``_ScalarsResult``."""
        wrapped = _ScalarsResult(self._items)
        return wrapped
|
|
|
|
|
|
class _FakeAsyncDB:
    """In-memory fake of the async DB session object used by these tests.

    It serves one user and at most one conversation, keeps a list of
    segments, and counts ``commit`` / ``refresh`` calls.
    """

    def __init__(self, user, conversation=None, segments=None):
        self.user = user
        self.conversation = conversation
        # Always own a private list, even when the caller passes nothing.
        self.segments = [*segments] if segments else []
        self.added = []
        self.commit_calls = 0
        self.refresh_calls = 0

    async def get(self, model, key):
        """Look up by model class name.

        "User" always yields the stored user. "Conversation" yields the
        stored conversation only when its ``id`` equals ``key``. Anything else
        yields ``None``.
        """
        kind = getattr(model, "__name__", "")
        if kind == "User":
            return self.user
        if kind != "Conversation":
            return None
        current = self.conversation
        return current if current and current.id == key else None

    def add(self, obj):
        """Record ``obj`` and mirror it into the conversation/segment state."""
        self.added.append(obj)
        if isinstance(obj, Conversation):
            self.conversation = obj
        if isinstance(obj, Segment):
            self.segments.append(obj)

    async def commit(self):
        """Count the commit; nothing is persisted."""
        self.commit_calls = self.commit_calls + 1

    async def refresh(self, obj):
        """Count the refresh; the object itself is left untouched."""
        del obj  # intentionally unused
        self.refresh_calls = self.refresh_calls + 1

    async def execute(self, stmt):
        """Ignore the statement and return every stored segment."""
        del stmt  # intentionally unused
        return _ExecuteResult(self.segments)
|
|
|
|
|
|
class _FakeManager:
    """Recording fake that replaces ``ws_router.manager`` in these tests.

    It tracks connections, outbound messages and disconnect calls, and exposes
    ``AsyncMock``-backed ``background_runner`` and ``conversation_agent``
    namespaces.
    """

    def __init__(self):
        self.active_connections = {}
        self.segment_states = {}
        self.sent_messages = []
        self.disconnect_calls = []
        self.background_runner = SimpleNamespace(
            queue_message=AsyncMock(),
            flush_pending=AsyncMock(),
        )
        # Both agent hooks resolve to an empty list.
        self.conversation_agent = SimpleNamespace(
            generate_profile_greeting=AsyncMock(return_value=[]),
            generate_opening_message=AsyncMock(return_value=[]),
        )

    async def connect(self, websocket, conversation_id):
        """Register ``websocket`` as the connection for ``conversation_id``."""
        self.active_connections[conversation_id] = websocket

    async def disconnect(self, conversation_id):
        """Log the disconnect and drop the connection if one is registered."""
        self.disconnect_calls.append(conversation_id)
        self.active_connections.pop(conversation_id, None)

    async def send_message(self, conversation_id, message):
        """Append the outbound message to ``sent_messages``."""
        self.sent_messages.append(
            dict(conversation_id=conversation_id, message=message)
        )

    def get_or_create_segment_state(self, conversation_id, voice_session_id):
        """Return the state for this (conversation, voice session) pair.

        A new ``ws_router.SegmentStreamState`` is created on first use.
        """
        key = (conversation_id, voice_session_id)
        states = self.segment_states
        if key in states:
            return states[key]
        fresh = ws_router.SegmentStreamState()
        states[key] = fresh
        return fresh

    def register_segment_task(self, conversation_id, voice_session_id, task):
        """Track ``task`` on the segment state until the task completes."""
        state = self.get_or_create_segment_state(conversation_id, voice_session_id)
        state.active_tasks.add(task)
        # Resolve ``state.active_tasks`` at completion time, not registration time.
        task.add_done_callback(lambda finished: state.active_tasks.discard(finished))
|
|
|
|
|
|
def _make_user():
    """Build a fake premium user with every profile attribute populated."""
    # Provide all profile fields to skip greeting/profile-collection branch.
    profile = {
        "id": "user-1",
        "nickname": "tester",
        "subscription_type": "premium",
        "birth_year": 1990,
        "birth_place": "A",
        "grew_up_place": "B",
        "occupation": "dev",
    }
    return SimpleNamespace(**profile)
|
|
|
|
|
|
def _db_provider(db):
    """Return an async-generator function that yields ``db`` exactly once.

    The tests install the returned callable in place of
    ``ws_router.get_async_db``.
    """

    async def _yield_db():
        yield db

    return _yield_db
|
|
|
|
|
|
def _redis_empty_history_patch():
    """Build a patcher that makes the redis conversation history come back empty.

    ``ws_router.redis_service.get_conversation_history`` is replaced with an
    ``AsyncMock`` returning ``[]``, so the websocket sends its opening message
    (or skips it if that is mocked).
    """
    empty_history = AsyncMock(return_value=[])
    return patch.object(
        ws_router.redis_service,
        "get_conversation_history",
        new=empty_history,
    )
|
|
|
|
|
|
class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
    """Baseline behaviour tests for ``ws_router.websocket_endpoint``.

    Each test scripts inbound frames on a ``_FakeWebSocket``, replaces the
    router's collaborators with fakes/mocks, runs the endpoint until the
    script ends, and then asserts on what the fakes recorded.

    The fixture construction and the patch scaffold shared by the tests live
    in the private helpers below, so each test lists only what is specific to
    it. Test methods carry ``#`` comments instead of docstrings so the unittest
    report keeps showing the method names.
    """

    CONVERSATION_ID = "conv-1"
    # Pause applied after the endpoint returns in the audio_segment tests
    # (presumably so background segment tasks get a chance to run).
    SETTLE_SECONDS = 0.05

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _make_fixture(cls, messages, segments=None):
        """Build the user, conversation, fake DB, fake manager and scripted socket.

        Args:
            messages: Frames (or exception instances) replayed by the socket.
            segments: Optional pre-existing segments served by the fake DB.

        Returns:
            A ``SimpleNamespace`` with ``user``, ``conversation``, ``db``,
            ``manager`` and ``websocket`` attributes.
        """
        user = _make_user()
        conversation = Conversation(
            id=cls.CONVERSATION_ID, user_id=user.id, status="active"
        )
        return SimpleNamespace(
            user=user,
            conversation=conversation,
            db=_FakeAsyncDB(user=user, conversation=conversation, segments=segments),
            manager=_FakeManager(),
            websocket=_FakeWebSocket(messages=messages),
        )

    async def _run_endpoint(
        self,
        fixture,
        *,
        with_quota=True,
        process_user_message=None,
        transcribe=None,
        process_conversation_segments=None,
        settle=False,
    ):
        """Run the endpoint against ``fixture`` with its collaborators patched.

        Token verification, the DB provider, the connection manager and the
        redis history are always patched. The remaining patches are opt-in so
        every test keeps exactly the set of patches it needs.

        Args:
            fixture: Namespace returned by ``_make_fixture``.
            with_quota: Patch the ``routers.quota`` checks to always allow.
            process_user_message: Mock installed as ``ws_router.process_user_message``.
            transcribe: Mock installed as ``ws_router.asr_service.transcribe``.
            process_conversation_segments: Mock installed as
                ``ws_router.process_conversation_segments``.
            settle: Sleep ``SETTLE_SECONDS`` after the endpoint returns, while
                the patches are still active.
        """
        with ExitStack() as stack:
            enter = stack.enter_context
            enter(
                patch.object(
                    ws_router,
                    "verify_token",
                    return_value={"type": "access", "sub": fixture.user.id},
                )
            )
            enter(patch.object(ws_router, "get_async_db", _db_provider(fixture.db)))
            enter(patch.object(ws_router, "manager", fixture.manager))
            enter(_redis_empty_history_patch())
            if with_quota:
                enter(
                    patch(
                        "routers.quota.get_segment_count",
                        new=AsyncMock(return_value=0),
                    )
                )
                enter(
                    patch(
                        "routers.quota.check_can_send_message",
                        return_value=(True, ""),
                    )
                )
            if process_user_message is not None:
                enter(
                    patch.object(
                        ws_router, "process_user_message", process_user_message
                    )
                )
            if transcribe is not None:
                enter(patch.object(ws_router.asr_service, "transcribe", transcribe))
            if process_conversation_segments is not None:
                enter(
                    patch.object(
                        ws_router,
                        "process_conversation_segments",
                        process_conversation_segments,
                    )
                )

            await ws_router.websocket_endpoint(
                fixture.websocket, self.CONVERSATION_ID
            )
            if settle:
                await asyncio.sleep(self.SETTLE_SECONDS)

    @staticmethod
    def _sent_of_type(manager, message_type):
        """Return the payloads sent through ``manager`` whose type matches."""
        return [
            item["message"]
            for item in manager.sent_messages
            if item["message"]["type"] == message_type
        ]

    @classmethod
    def _transition_messages(cls, manager):
        """Return AGENT_RESPONSE payloads whose ``data.transition`` is ``True``."""
        return [
            message
            for message in cls._sent_of_type(
                manager, ws_router.MessageType.AGENT_RESPONSE
            )
            if message.get("data", {}).get("transition") is True
        ]

    @staticmethod
    def _added_segments(db):
        """Return the ``Segment`` objects that were passed to ``db.add``."""
        return [obj for obj in db.added if isinstance(obj, Segment)]

    @staticmethod
    def _audio_segment(
        audio_base64,
        segment_index,
        duration,
        is_last,
        voice_session_id=None,
        client_segment_id=None,
    ):
        """Build an ``audio_segment`` frame; optional ids are omitted when ``None``."""
        data = {
            "audio_base64": audio_base64,
            "segment_index": segment_index,
            "duration": duration,
            "is_last": is_last,
        }
        if voice_session_id is not None:
            data["voice_session_id"] = voice_session_id
        if client_segment_id is not None:
            data["client_segment_id"] = client_segment_id
        return {"type": "audio_segment", "data": data}

    # -------------------------------------------------------------------- tests

    async def test_invalid_token_closes_connection(self):
        # A token that fails verification must close the socket exactly once.
        websocket = _FakeWebSocket(messages=[], token="invalid")

        with patch.object(ws_router, "verify_token", return_value=None):
            await ws_router.websocket_endpoint(websocket, self.CONVERSATION_ID)

        self.assertEqual(len(websocket.close_calls), 1)
        self.assertEqual(websocket.close_calls[0]["reason"], "无效的认证令牌")

    async def test_text_message_creates_segment_and_dispatches_agent(self):
        fx = self._make_fixture(
            [
                {"type": "text", "data": {"text": "你好"}},
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()

        await self._run_endpoint(fx, process_user_message=process_user_message_mock)

        segments = self._added_segments(fx.db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "你好")
        self.assertIsNone(segments[0].audio_url)
        self.assertIsNotNone(fx.conversation.last_message_at)
        fx.manager.background_runner.queue_message.assert_awaited_once()
        process_user_message_mock.assert_awaited_once()

        message_types = [item["message"]["type"] for item in fx.manager.sent_messages]
        self.assertIn(ws_router.MessageType.CONNECT, message_types)

    async def test_audio_message_transcribes_then_calls_agent(self):
        fx = self._make_fixture(
            [
                {
                    "type": "audio_message",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 12},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="这是转写结果")

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
        )

        transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
        process_user_message_mock.assert_awaited_once()

        segments = self._added_segments(fx.db)
        self.assertEqual(len(segments), 1)
        self.assertEqual(segments[0].transcript_text, "这是转写结果")
        self.assertEqual(segments[0].audio_url, "audio:12s")
        self.assertIsNotNone(fx.conversation.last_message_at)

        transcript_msgs = self._sent_of_type(
            fx.manager, ws_router.MessageType.TRANSCRIPT
        )
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "这是转写结果")

    async def test_transcribe_only_returns_transcript_without_segment_or_agent(self):
        fx = self._make_fixture(
            [
                {
                    "type": "transcribe_only",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ="},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="仅转写文本")

        # The quota checks are deliberately left unpatched for this path.
        await self._run_endpoint(
            fx,
            with_quota=False,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
        )

        transcribe_mock.assert_awaited_once_with("ZmFrZS1hdWRpby1iNjQ=")
        process_user_message_mock.assert_not_awaited()
        self.assertEqual(len(self._added_segments(fx.db)), 0)

        transcript_msgs = self._sent_of_type(
            fx.manager, ws_router.MessageType.TRANSCRIPT
        )
        self.assertEqual(len(transcript_msgs), 1)
        self.assertEqual(transcript_msgs[0]["data"]["text"], "仅转写文本")

    async def test_end_conversation_updates_status_and_triggers_processing(self):
        # No trailing disconnect is scripted for this frame.
        fx = self._make_fixture(
            [{"type": "end_conversation", "conversation_id": "conv-1"}]
        )
        process_conversation_segments_mock = AsyncMock()

        await self._run_endpoint(
            fx,
            with_quota=False,
            process_conversation_segments=process_conversation_segments_mock,
        )

        self.assertEqual(fx.conversation.status, "ended")
        process_conversation_segments_mock.assert_awaited_once()
        end_msgs = self._sent_of_type(
            fx.manager, ws_router.MessageType.END_CONVERSATION
        )
        self.assertEqual(len(end_msgs), 1)

    async def test_transcribe_only_missing_audio_returns_error(self):
        fx = self._make_fixture(
            [
                {"type": "transcribe_only", "data": {}},
                WebSocketDisconnect(),
            ]
        )
        transcribe_mock = AsyncMock(return_value="should-not-be-called")

        await self._run_endpoint(fx, with_quota=False, transcribe=transcribe_mock)

        transcribe_mock.assert_not_awaited()
        error_msgs = self._sent_of_type(fx.manager, ws_router.MessageType.ERROR)
        self.assertEqual(len(error_msgs), 1)
        self.assertEqual(error_msgs[0]["data"]["message"], "缺少 audio_base64")

    async def test_audio_message_transcribe_failure_sends_error_and_skips_agent(self):
        fx = self._make_fixture(
            [
                {
                    "type": "audio_message",
                    "data": {"audio_base64": "ZmFrZS1hdWRpby1iNjQ=", "duration": 8},
                },
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        # The transcriber reports failure through its returned text.
        transcribe_mock = AsyncMock(return_value="转写失败: mock error")

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
        )

        process_user_message_mock.assert_not_awaited()
        error_msgs = self._sent_of_type(fx.manager, ws_router.MessageType.ERROR)
        self.assertEqual(len(error_msgs), 1)
        self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")

    async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self):
        # Segment 1 arrives before segment 0; the agent must still see 0 then 1.
        fx = self._make_fixture(
            [
                self._audio_segment(
                    "seg-1", 1, 12, False, voice_session_id="voice-session-1"
                ),
                self._audio_segment(
                    "seg-0", 0, 10, False, voice_session_id="voice-session-1"
                ),
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(
            side_effect=lambda audio: {
                "seg-0": "这是第 0 段",
                "seg-1": "这是第 1 段",
            }[audio]
        )

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
            settle=True,
        )

        self.assertEqual(transcribe_mock.await_count, 2)
        ordered_messages = [
            call.kwargs["user_message"]
            for call in process_user_message_mock.await_args_list
        ]
        self.assertEqual(ordered_messages, ["这是第 0 段", "这是第 1 段"])
        self.assertEqual(len(self._added_segments(fx.db)), 2)
        transcript_msgs = self._sent_of_type(
            fx.manager, ws_router.MessageType.TRANSCRIPT
        )
        self.assertEqual(
            [msg["data"]["voice_session_id"] for msg in transcript_msgs],
            ["voice-session-1", "voice-session-1"],
        )

    async def test_audio_segment_duplicate_index_is_idempotent(self):
        # Two frames share segment_index 0 and carry no voice_session_id.
        fx = self._make_fixture(
            [
                self._audio_segment("dup-seg-0-a", 0, 10, False),
                self._audio_segment("dup-seg-0-b", 0, 10, True),
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="重复分段去重测试")

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
            settle=True,
        )

        self.assertEqual(transcribe_mock.await_count, 1)
        process_user_message_mock.assert_awaited_once()
        self.assertEqual(len(self._added_segments(fx.db)), 1)

    async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(self):
        fx = self._make_fixture(
            [
                self._audio_segment(
                    "session-a-seg-0",
                    0,
                    10,
                    True,
                    voice_session_id="voice-session-a",
                    client_segment_id="voice-session-a-0",
                ),
                self._audio_segment(
                    "session-b-seg-0",
                    0,
                    8,
                    True,
                    voice_session_id="voice-session-b",
                    client_segment_id="voice-session-b-0",
                ),
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(
            side_effect=lambda audio: {
                "session-a-seg-0": "第一轮第 0 段",
                "session-b-seg-0": "第二轮第 0 段",
            }[audio]
        )

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
            settle=True,
        )

        ordered_messages = [
            call.kwargs["user_message"]
            for call in process_user_message_mock.await_args_list
        ]
        self.assertEqual(ordered_messages, ["第一轮第 0 段", "第二轮第 0 段"])
        self.assertEqual(transcribe_mock.await_count, 2)
        self.assertEqual(len(self._added_segments(fx.db)), 2)

    async def test_audio_segment_sends_transition_feedback_while_processing(self):
        fx = self._make_fixture(
            [
                self._audio_segment("slow-seg-0", 0, 20, True),
                WebSocketDisconnect(),
            ]
        )

        async def _slow_transcribe(_: str) -> str:
            # Deliberately slower than the settle pause applied after the endpoint.
            await asyncio.sleep(0.2)
            return "慢速转写"

        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(side_effect=_slow_transcribe)

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
            settle=True,
        )

        transition_msgs = self._transition_messages(fx.manager)
        self.assertGreaterEqual(len(transition_msgs), 1)

    async def test_audio_segment_last_segment_does_not_emit_terminal_transition(self):
        fx = self._make_fixture(
            [
                self._audio_segment(
                    "last-seg-0",
                    0,
                    15,
                    True,
                    voice_session_id="voice-session-last",
                    client_segment_id="voice-session-last-0",
                ),
                WebSocketDisconnect(),
            ]
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="最后一段转写")

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
            settle=True,
        )

        transition_msgs = self._transition_messages(fx.manager)
        self.assertEqual(len(transition_msgs), 1)
        self.assertIsNone(transition_msgs[0]["data"].get("is_last"))

    async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self):
        # Segment 0 of this voice session is already stored before the socket opens.
        existing_segment = Segment(
            id="seg-existing-0",
            conversation_id="conv-1",
            transcript_text="已存在的上一段",
            audio_url="audio-segment:voice-session-1:0",
            processed=False,
        )
        fx = self._make_fixture(
            [
                self._audio_segment(
                    "seg-1-after-reconnect",
                    1,
                    18,
                    True,
                    voice_session_id="voice-session-1",
                    client_segment_id="voice-session-1-1",
                ),
                WebSocketDisconnect(),
            ],
            segments=[existing_segment],
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="这是重连后的第 1 段")

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
            settle=True,
        )

        process_user_message_mock.assert_awaited_once()
        self.assertEqual(
            process_user_message_mock.await_args.kwargs["user_message"],
            "这是重连后的第 1 段",
        )

    async def test_audio_segment_reconnect_uses_contiguous_prefix_not_max_index(self):
        # Stored segments 0 and 2 leave a gap at index 1, which is then uploaded.
        existing_segments = [
            Segment(
                id="seg-existing-0",
                conversation_id="conv-1",
                transcript_text="已存在的第 0 段",
                audio_url="audio-segment:voice-session-gap:0",
                processed=False,
            ),
            Segment(
                id="seg-existing-2",
                conversation_id="conv-1",
                transcript_text="已存在的第 2 段",
                audio_url="audio-segment:voice-session-gap:2",
                processed=False,
            ),
        ]
        fx = self._make_fixture(
            [
                self._audio_segment(
                    "seg-1-gap-retry",
                    1,
                    18,
                    False,
                    voice_session_id="voice-session-gap",
                    client_segment_id="voice-session-gap-1",
                ),
                WebSocketDisconnect(),
            ],
            segments=existing_segments,
        )
        process_user_message_mock = AsyncMock()
        transcribe_mock = AsyncMock(return_value="补传后的第 1 段")

        await self._run_endpoint(
            fx,
            process_user_message=process_user_message_mock,
            transcribe=transcribe_mock,
            settle=True,
        )

        process_user_message_mock.assert_awaited_once()
        self.assertEqual(
            process_user_message_mock.await_args.kwargs["user_message"],
            "补传后的第 1 段",
        )
        self.assertEqual(transcribe_mock.await_count, 1)
|
|
|
|
|
|
# Run the suite through unittest's command-line runner when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
|