feat: 支持长语音分段上传与断线补传

This commit is contained in:
Kevin
2026-03-09 15:30:18 +08:00
parent 440f5be07f
commit 6ffe96d7a9
13 changed files with 1451 additions and 19 deletions

View File

@@ -139,6 +139,20 @@ init_db()
### 本地开发
推荐使用一键脚本(会自动启动 PostgreSQL/Redis、检查 `.venv`、安装依赖并拉起 FastAPI + Celery):
```bash
cd api
./dev-up.sh
```
可选环境变量:
- `SKIP_INSTALL=1`:跳过依赖安装
- `API_HOST` / `API_PORT`:覆盖 API 启动地址和端口
- `CELERY_POOL`:覆盖 Celery 池类型(macOS 推荐 `solo`)
也可以使用手动方式:
```bash
cd api

View File

@@ -2,13 +2,15 @@
WebSocket 路由:实时对话通信
支持异步 Agent 调用和 Redis 会话存储
"""
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Dict
from typing import Dict, List, Optional, Set, Tuple
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
from starlette.websockets import WebSocketState
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -21,15 +23,16 @@ from database.models import User as UserModel
from services.auth_service import verify_token
from services.memoir_state_service import get_or_create_state
from services import asr_service
from fastapi import HTTPException, status
# Module-level logger for this router.
logger = logging.getLogger(__name__)
# Voice-session id substituted when a message carries no usable session id.
LEGACY_VOICE_SESSION_ID = "legacy"
class MessageType(str, Enum):
"""WebSocket 消息类型"""
CONNECT = "connect"
AUDIO_CHUNK = "audio_chunk"
AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传)
AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音)
TRANSCRIBE_ONLY = "transcribe_only" # 仅转写,不落库、不触发 Agent用于「转文字」发送
TEXT = "text" # 文本消息
@@ -47,6 +50,7 @@ class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.segment_states: Dict[Tuple[str, str], "SegmentStreamState"] = {}
# ConversationAgent 现在是无状态的(会话存储在 Redis),可以复用
self.conversation_agent = ConversationAgent()
self.memory_agent = MemoryAgent()
@@ -61,8 +65,51 @@ class ConnectionManager:
"""断开连接"""
if conversation_id in self.active_connections:
del self.active_connections[conversation_id]
stale_keys = [
key
for key, state in self.segment_states.items()
if key[0] == conversation_id and not state.active_tasks
]
for key in stale_keys:
self.segment_states.pop(key, None)
# 清除 Redis 中的会话记忆(可选,也可以保留用于恢复)
# await self.conversation_agent.clear_memory(conversation_id)
def get_or_create_segment_state(
    self,
    conversation_id: str,
    voice_session_id: str,
) -> "SegmentStreamState":
    """Return the segment state for one voice session, creating it on first use.

    States live in ``self.segment_states`` and are keyed by the
    ``(conversation_id, voice_session_id)`` pair.
    """
    key = (conversation_id, voice_session_id)
    try:
        return self.segment_states[key]
    except KeyError:
        fresh_state = SegmentStreamState()
        self.segment_states[key] = fresh_state
        return fresh_state
def register_segment_task(
    self,
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track a background segment task and tidy up once it finishes.

    The task is added to the ``active_tasks`` set of the matching segment
    state. When the task completes it is removed again; if that leaves the
    state without tasks and the conversation is no longer connected, the
    state itself is dropped from ``self.segment_states``. An exception
    raised by the task is logged together with its traceback.

    Args:
        conversation_id: Conversation the segment belongs to.
        voice_session_id: Voice session the segment belongs to.
        task: The asyncio task processing the segment.
    """
    state_key = (conversation_id, voice_session_id)
    state = self.get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        # Keep the state while the connection is alive so later segments of
        # the same voice session still see the bookkeeping.
        if not state.active_tasks and conversation_id not in self.active_connections:
            self.segment_states.pop(state_key, None)
        if done_task.cancelled():
            # exception() would raise CancelledError on a cancelled task.
            return
        exc = done_task.exception()
        if exc:
            # A done-callback runs outside any ``except`` block, so
            # ``exc_info=True`` would log "NoneType: None"; hand the
            # exception instance over so the real traceback is recorded.
            logger.error(
                "分段处理任务异常 "
                f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
                exc_info=exc,
            )

    task.add_done_callback(_cleanup)
async def send_message(self, conversation_id: str, message: dict):
"""发送消息"""
@@ -88,6 +135,286 @@ class ConnectionManager:
# Module-level ConnectionManager instance.
manager = ConnectionManager()
@dataclass
class SegmentStreamState:
    """Per-voice-session segment state (parallel ASR with ordered aggregation)."""

    # Held while reading or updating the index bookkeeping below.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Indices accepted by the WebSocket loop whose processing has not finished.
    pending_indices: Set[int] = field(default_factory=set)
    # Indices that were persisted, or found already persisted.
    processed_indices: Set[int] = field(default_factory=set)
    # index -> (transcript, Segment) waiting until the preceding indices are consumed.
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest index of the contiguous prefix already consumed; -1 means none yet.
    consumed_index: int = -1
    # Background tasks currently registered for this voice session.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
    """Return the session id as a string, or the legacy placeholder when it is empty or missing."""
    return str(voice_session_id) if voice_session_id else LEGACY_VOICE_SESSION_ID
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
    """Derive the voice-session id from a ``<session>-<suffix>`` client segment id.

    Everything before the last ``-`` is taken as the session id. ``None`` is
    returned when the id is empty, contains no ``-``, or has nothing before it.
    """
    if client_segment_id:
        prefix, dash, _suffix = client_segment_id.rpartition("-")
        if dash and prefix:
            return prefix
    return None
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
    """Build the idempotency marker stored in ``audio_url`` for one voice segment.

    The marker encodes the voice session and the segment index; lookups pair
    it with the conversation id to identify a segment.
    """
    return "audio-segment:{}:{}".format(voice_session_id, segment_index)
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse ``(voice_session_id, segment_index)`` out of a segment ``audio_url``.

    Accepts both ``audio-segment:{session}:{index}`` and the older
    ``audio-segment:{index}`` form, which maps to the legacy session id.
    Returns ``None`` for anything that is not a segment marker or whose
    index is not an integer.
    """
    marker = "audio-segment:"
    if not audio_url or not audio_url.startswith(marker):
        return None
    body = audio_url[len(marker):]
    session_part, colon, index_part = body.rpartition(":")
    try:
        if not colon:
            # Old format: the payload is just the index.
            return (LEGACY_VOICE_SESSION_ID, int(body))
        return (_normalize_voice_session_id(session_part), int(index_part))
    except ValueError:
        return None
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
    """Tell whether a transcript is empty/missing or starts with the failure marker."""
    return not transcript_text or transcript_text.startswith("转写失败")
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Look up an already stored segment by conversation, voice session and index.

    The query filters on ``conversation_id`` and the segment ``audio_url``
    marker, newest first. The rows are filtered again in Python because the
    test stub's ``execute()`` does not apply the ``where`` clause; this keeps
    the function correct against both a real database and the stub.
    """
    expected_url = _build_segment_audio_url(voice_session_id, segment_index)
    query = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == expected_url,
        )
        .order_by(Segment.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id and row.audio_url == expected_url
        ),
        None,
    )
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the end of the contiguous stored segment prefix for a voice session.

    Collects the indices of all stored segments of this conversation whose
    ``audio_url`` belongs to ``voice_session_id`` and returns the largest
    ``k`` such that indices ``0..k`` are all present, or ``-1`` when index 0
    is missing. Used to restore state after a reconnect.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()

    stored: Set[int] = set()
    for row in rows:
        # Re-check in Python as well; the query filter alone is not relied on.
        if row.conversation_id != conversation_id:
            continue
        scope = _extract_segment_scope(row.audio_url)
        if scope and scope[0] == voice_session_id:
            stored.add(scope[1])

    next_expected = 0
    while next_expected in stored:
        next_expected += 1
    return next_expected - 1
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
    manager: ConnectionManager,
) -> None:
    """Send a short "still listening" reply so the user does not perceive a stall during ASR."""
    feedback = {
        "text": "我在认真听,你继续说,我会边听边整理重点。",
        "transition": True,
        "segment_index": segment_index,
    }
    envelope = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": feedback,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, envelope)
async def _process_audio_segment_async(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
    manager: ConnectionManager,
) -> None:
    """Process one uploaded voice segment in the background.

    Transcribes the audio, stores the segment unless one with the same
    conversation / voice session / index already exists, and hands
    transcripts to ``process_user_message`` in ascending ``segment_index``
    order. Exceptions are logged and reported to the client as an ERROR
    message; the index is always removed from ``pending_indices`` at the end.
    """
    state = manager.get_or_create_segment_state(conversation_id, voice_session_id)
    try:
        # Each segment task uses its own DB session so it never shares an
        # AsyncSession with the main receive loop.
        async for db in get_async_db():
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(UserModel, user_id)
            if not conversation:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "对话不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return
            if not user:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "用户不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            # A state that has not consumed, processed or buffered anything
            # yet is seeded from the database (reconnect recovery).
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )
            if should_prime_state:
                persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
                    db=db,
                    conversation_id=conversation_id,
                    voice_session_id=voice_session_id,
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        state.consumed_index = max(state.consumed_index, persisted_contiguous_index)

            transcript_text = await asr_service.transcribe(audio_base64)
            # The transcript is pushed to the client before the failure check.
            await manager.send_message(conversation_id, {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {
                    "text": transcript_text or "",
                    "audio_duration": audio_duration,
                    "segment_index": segment_index,
                    "is_last": is_last,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            if _is_transcribe_failure(transcript_text):
                # Returning here leaves the index unprocessed; the finally
                # clause removes it from pending_indices.
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {
                        "message": f"分段 {segment_index} 转写失败,请重试该片段",
                        "segment_index": segment_index,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                # Already stored: treat as a retransmission — do not store it
                # again and do not trigger the agent again.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.info(
                    "分段已存在,按幂等处理跳过: "
                    f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
                )
                return
            else:
                segment = Segment(
                    id=str(uuid.uuid4()),
                    conversation_id=conversation_id,
                    transcript_text=transcript_text or "",
                    audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                    processed=False,
                )
                db.add(segment)
                await db.commit()
                await db.refresh(segment)
                await manager.background_runner.queue_message(conversation.user_id, segment.id)

            # Buffer this transcript, then drain every buffered entry that
            # directly extends the consumed prefix.
            ready_segments: List[Tuple[int, str, Segment]] = []
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1

            # The agent is triggered only for a contiguous prefix, so the
            # incremental context stays in the right order.
            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    manager=manager,
                    user=user,
                )

            if is_last:
                await manager.send_message(conversation_id, {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": "最后一段语音已收到,我会继续完善这一轮总结。",
                        "transition": True,
                        "is_last": True,
                        "segment_index": segment_index,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
            # Only the first session yielded by get_async_db() is used.
            break
    except Exception as e:
        logger.error(
            f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
            exc_info=True,
        )
        await manager.send_message(conversation_id, {
            "type": MessageType.ERROR,
            "data": {
                "message": f"分段处理失败: {str(e)}",
                "segment_index": segment_index,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
    finally:
        # Runs on every exit path, including the early returns above.
        async with state.lock:
            state.pending_indices.discard(segment_index)
async def websocket_endpoint(
websocket: WebSocket,
conversation_id: str
@@ -229,8 +556,109 @@ async def websocket_endpoint(
segment=segment,
db=db,
manager=manager,
user=user,
user=user,
)
elif msg_type == MessageType.AUDIO_SEGMENT:
# 处理分段语音消息(长语音持续上传)
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
segment_index_raw = data.get("segment_index")
voice_session_id = _normalize_voice_session_id(
data.get("voice_session_id")
or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
)
is_last = bool(data.get("is_last", False))
audio_duration = int(data.get("duration", 0) or 0)
if not audio_base64:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
if segment_index_raw is None:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 segment_index"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
try:
segment_index = int(segment_index_raw)
except (TypeError, ValueError):
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "segment_index 必须为整数"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
if segment_index < 0:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "segment_index 不能为负数"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
# 校验对话轮数配额(分段也计入对话轮次)
from routers.quota import get_segment_count, check_can_send_message
seg_count = await get_segment_count(user_id, db)
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
segment_state = manager.get_or_create_segment_state(
conversation_id,
voice_session_id,
)
should_process = False
async with segment_state.lock:
already_seen = (
segment_index in segment_state.pending_indices
or segment_index in segment_state.processed_indices
or segment_index <= segment_state.consumed_index
)
if not already_seen:
segment_state.pending_indices.add(segment_index)
should_process = True
if not should_process:
logger.info(
"收到重复分段,跳过处理: "
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
)
continue
# 先发过渡反馈,减少“等待空白”体感
await _send_segment_transition_feedback(
conversation_id=conversation_id,
segment_index=segment_index,
manager=manager,
)
task = asyncio.create_task(
_process_audio_segment_async(
conversation_id=conversation_id,
user_id=user_id,
voice_session_id=voice_session_id,
segment_index=segment_index,
audio_base64=audio_base64,
audio_duration=audio_duration,
is_last=is_last,
manager=manager,
)
)
manager.register_segment_task(conversation_id, voice_session_id, task)
elif msg_type == MessageType.AUDIO_MESSAGE:
# 处理完整音频消息(类似微信语音)
@@ -615,4 +1043,3 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
# 同时 flush 任何待处理的任务
await manager.background_runner.flush_pending(conversation.user_id)

View File

@@ -1,3 +1,4 @@
import asyncio
import unittest
from contextlib import ExitStack
from dataclasses import dataclass
@@ -89,6 +90,7 @@ class _FakeAsyncDB:
class _FakeManager:
def __init__(self):
self.active_connections = {}
self.segment_states = {}
self.sent_messages = []
self.disconnect_calls = []
self.background_runner = SimpleNamespace(
@@ -109,6 +111,21 @@ class _FakeManager:
async def send_message(self, conversation_id, message):
    """Record the outgoing message in ``sent_messages`` instead of sending it."""
    record = {"conversation_id": conversation_id, "message": message}
    self.sent_messages.append(record)
def get_or_create_segment_state(self, conversation_id, voice_session_id):
    """Return the state stored for the (conversation, voice session) pair, creating it if absent."""
    key = (conversation_id, voice_session_id)
    try:
        return self.segment_states[key]
    except KeyError:
        created = ws_router.SegmentStreamState()
        self.segment_states[key] = created
        return created
def register_segment_task(self, conversation_id, voice_session_id, task):
    """Add the task to the session's ``active_tasks`` and remove it again when it is done."""
    session_state = self.get_or_create_segment_state(conversation_id, voice_session_id)
    session_state.active_tasks.add(task)
    task.add_done_callback(
        lambda finished: session_state.active_tasks.discard(finished)
    )
def _make_user():
# Provide all profile fields to skip greeting/profile-collection branch.
@@ -435,6 +452,431 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(error_msgs), 1)
self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")
async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self):
    # Segment 1 is sent before segment 0. Both must be transcribed and
    # stored, and process_user_message must be called in index order (0, 1).
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1",
                    "segment_index": 1,
                    "duration": 12,
                    "is_last": False,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-0",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": False,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    # The transcript is keyed by the audio payload so each segment is identifiable.
    transcribe_mock = AsyncMock(
        side_effect=lambda audio: {
            "seg-0": "这是第 0 段",
            "seg-1": "这是第 1 段",
        }[audio]
    )
    with ExitStack() as stack:
        stack.enter_context(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        stack.enter_context(
            patch.object(ws_router, "get_async_db", _db_provider(fake_db))
        )
        stack.enter_context(patch.object(ws_router, "manager", fake_manager))
        stack.enter_context(
            patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
        )
        stack.enter_context(
            patch("routers.quota.check_can_send_message", return_value=(True, ""))
        )
        stack.enter_context(
            patch.object(ws_router, "process_user_message", process_user_message_mock)
        )
        stack.enter_context(
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
        )
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Give the background segment tasks time to run.
        await asyncio.sleep(0.05)
    self.assertEqual(transcribe_mock.await_count, 2)
    ordered_messages = [
        call.kwargs["user_message"] for call in process_user_message_mock.await_args_list
    ]
    self.assertEqual(ordered_messages, ["这是第 0 段", "这是第 1 段"])
    self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2)
async def test_audio_segment_duplicate_index_is_idempotent(self):
    # Index 0 is sent twice with different payloads. Only the first one may
    # be transcribed, stored and passed to process_user_message.
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "dup-seg-0-a",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": False,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "dup-seg-0-b",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="重复分段去重测试")
    with ExitStack() as stack:
        stack.enter_context(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        stack.enter_context(
            patch.object(ws_router, "get_async_db", _db_provider(fake_db))
        )
        stack.enter_context(patch.object(ws_router, "manager", fake_manager))
        stack.enter_context(
            patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
        )
        stack.enter_context(
            patch("routers.quota.check_can_send_message", return_value=(True, ""))
        )
        stack.enter_context(
            patch.object(ws_router, "process_user_message", process_user_message_mock)
        )
        stack.enter_context(
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
        )
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Give the background segment task time to run.
        await asyncio.sleep(0.05)
    self.assertEqual(transcribe_mock.await_count, 1)
    process_user_message_mock.assert_awaited_once()
    self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 1)
async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(self):
    # Index 0 is sent once for each of two voice sessions. The sessions are
    # tracked independently, so both segments must be processed and stored.
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "session-a-seg-0",
                    "voice_session_id": "voice-session-a",
                    "client_segment_id": "voice-session-a-0",
                    "segment_index": 0,
                    "duration": 10,
                    "is_last": True,
                },
            },
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "session-b-seg-0",
                    "voice_session_id": "voice-session-b",
                    "client_segment_id": "voice-session-b-0",
                    "segment_index": 0,
                    "duration": 8,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(
        side_effect=lambda audio: {
            "session-a-seg-0": "第一轮第 0 段",
            "session-b-seg-0": "第二轮第 0 段",
        }[audio]
    )
    with ExitStack() as stack:
        stack.enter_context(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        stack.enter_context(
            patch.object(ws_router, "get_async_db", _db_provider(fake_db))
        )
        stack.enter_context(patch.object(ws_router, "manager", fake_manager))
        stack.enter_context(
            patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
        )
        stack.enter_context(
            patch("routers.quota.check_can_send_message", return_value=(True, ""))
        )
        stack.enter_context(
            patch.object(ws_router, "process_user_message", process_user_message_mock)
        )
        stack.enter_context(
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
        )
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Give the background segment tasks time to run.
        await asyncio.sleep(0.05)
    ordered_messages = [
        call.kwargs["user_message"] for call in process_user_message_mock.await_args_list
    ]
    self.assertEqual(ordered_messages, ["第一轮第 0 段", "第二轮第 0 段"])
    self.assertEqual(transcribe_mock.await_count, 2)
    self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2)
async def test_audio_segment_sends_transition_feedback_while_processing(self):
    # Transcription takes 0.2 s but the test waits only 0.05 s, so any
    # transition message seen here was sent while ASR was still running.
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    fake_db = _FakeAsyncDB(user=user, conversation=conversation)
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "slow-seg-0",
                    "segment_index": 0,
                    "duration": 20,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )

    async def _slow_transcribe(_: str) -> str:
        await asyncio.sleep(0.2)
        return "慢速转写"

    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(side_effect=_slow_transcribe)
    with ExitStack() as stack:
        stack.enter_context(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        stack.enter_context(
            patch.object(ws_router, "get_async_db", _db_provider(fake_db))
        )
        stack.enter_context(patch.object(ws_router, "manager", fake_manager))
        stack.enter_context(
            patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
        )
        stack.enter_context(
            patch("routers.quota.check_can_send_message", return_value=(True, ""))
        )
        stack.enter_context(
            patch.object(ws_router, "process_user_message", process_user_message_mock)
        )
        stack.enter_context(
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
        )
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        await asyncio.sleep(0.05)
    # Keep only agent responses flagged as transition feedback.
    transition_msgs = [
        item["message"]
        for item in fake_manager.sent_messages
        if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
        and item["message"].get("data", {}).get("transition") is True
    ]
    self.assertGreaterEqual(len(transition_msgs), 1)
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self):
    # Segment 0 of the voice session is already stored. Sending segment 1 on
    # a fresh manager must still reach process_user_message exactly once.
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    existing_segment = Segment(
        id="seg-existing-0",
        conversation_id="conv-1",
        transcript_text="已存在的上一段",
        audio_url="audio-segment:voice-session-1:0",
        processed=False,
    )
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=[existing_segment],
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1-after-reconnect",
                    "voice_session_id": "voice-session-1",
                    "client_segment_id": "voice-session-1-1",
                    "segment_index": 1,
                    "duration": 18,
                    "is_last": True,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="这是重连后的第 1 段")
    with ExitStack() as stack:
        stack.enter_context(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        stack.enter_context(
            patch.object(ws_router, "get_async_db", _db_provider(fake_db))
        )
        stack.enter_context(patch.object(ws_router, "manager", fake_manager))
        stack.enter_context(
            patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
        )
        stack.enter_context(
            patch("routers.quota.check_can_send_message", return_value=(True, ""))
        )
        stack.enter_context(
            patch.object(ws_router, "process_user_message", process_user_message_mock)
        )
        stack.enter_context(
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
        )
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Give the background segment task time to run.
        await asyncio.sleep(0.05)
    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        process_user_message_mock.await_args.kwargs["user_message"],
        "这是重连后的第 1 段",
    )
async def test_audio_segment_reconnect_uses_contiguous_prefix_not_max_index(self):
    # Segments 0 and 2 are stored but 1 is missing. Recovery must resume
    # after the contiguous prefix (index 0), so the re-sent segment 1 is
    # still transcribed and passed to process_user_message.
    user = _make_user()
    conversation = Conversation(id="conv-1", user_id=user.id, status="active")
    existing_segments = [
        Segment(
            id="seg-existing-0",
            conversation_id="conv-1",
            transcript_text="已存在的第 0 段",
            audio_url="audio-segment:voice-session-gap:0",
            processed=False,
        ),
        Segment(
            id="seg-existing-2",
            conversation_id="conv-1",
            transcript_text="已存在的第 2 段",
            audio_url="audio-segment:voice-session-gap:2",
            processed=False,
        ),
    ]
    fake_db = _FakeAsyncDB(
        user=user,
        conversation=conversation,
        segments=existing_segments,
    )
    fake_manager = _FakeManager()
    fake_websocket = _FakeWebSocket(
        messages=[
            {
                "type": "audio_segment",
                "data": {
                    "audio_base64": "seg-1-gap-retry",
                    "voice_session_id": "voice-session-gap",
                    "client_segment_id": "voice-session-gap-1",
                    "segment_index": 1,
                    "duration": 18,
                    "is_last": False,
                },
            },
            WebSocketDisconnect(),
        ]
    )
    process_user_message_mock = AsyncMock()
    transcribe_mock = AsyncMock(return_value="补传后的第 1 段")
    with ExitStack() as stack:
        stack.enter_context(
            patch.object(
                ws_router,
                "verify_token",
                return_value={"type": "access", "sub": user.id},
            )
        )
        stack.enter_context(
            patch.object(ws_router, "get_async_db", _db_provider(fake_db))
        )
        stack.enter_context(patch.object(ws_router, "manager", fake_manager))
        stack.enter_context(
            patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
        )
        stack.enter_context(
            patch("routers.quota.check_can_send_message", return_value=(True, ""))
        )
        stack.enter_context(
            patch.object(ws_router, "process_user_message", process_user_message_mock)
        )
        stack.enter_context(
            patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
        )
        await ws_router.websocket_endpoint(fake_websocket, "conv-1")
        # Give the background segment task time to run.
        await asyncio.sleep(0.05)
    process_user_message_mock.assert_awaited_once()
    self.assertEqual(
        process_user_message_mock.await_args.kwargs["user_message"],
        "补传后的第 1 段",
    )
    self.assertEqual(transcribe_mock.await_count, 1)
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()