feat: 支持长语音分段上传与断线补传
This commit is contained in:
@@ -139,6 +139,20 @@ init_db()
|
||||
|
||||
### 本地开发
|
||||
|
||||
推荐使用一键脚本(会自动启动 PostgreSQL/Redis、检查 `.venv`、安装依赖并拉起 FastAPI + Celery):
|
||||
|
||||
```bash
|
||||
cd api
|
||||
./dev-up.sh
|
||||
```
|
||||
|
||||
可选环境变量:
|
||||
- `SKIP_INSTALL=1`:跳过依赖安装
|
||||
- `API_HOST` / `API_PORT`:覆盖 API 启动地址和端口
|
||||
- `CELERY_POOL`:覆盖 Celery 池类型(macOS 推荐 `solo`)
|
||||
|
||||
也可以使用手动方式:
|
||||
|
||||
```bash
|
||||
cd api
|
||||
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
WebSocket 路由:实时对话通信
|
||||
支持异步 Agent 调用和 Redis 会话存储
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
|
||||
from starlette.websockets import WebSocketState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -21,15 +23,16 @@ from database.models import User as UserModel
|
||||
from services.auth_service import verify_token
|
||||
from services.memoir_state_service import get_or_create_state
|
||||
from services import asr_service
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LEGACY_VOICE_SESSION_ID = "legacy"
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""WebSocket 消息类型"""
|
||||
CONNECT = "connect"
|
||||
AUDIO_CHUNK = "audio_chunk"
|
||||
AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传)
|
||||
AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音)
|
||||
TRANSCRIBE_ONLY = "transcribe_only" # 仅转写,不落库、不触发 Agent,用于「转文字」发送
|
||||
TEXT = "text" # 文本消息
|
||||
@@ -47,6 +50,7 @@ class ConnectionManager:
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.segment_states: Dict[Tuple[str, str], "SegmentStreamState"] = {}
|
||||
# ConversationAgent 现在是无状态的(会话存储在 Redis),可以复用
|
||||
self.conversation_agent = ConversationAgent()
|
||||
self.memory_agent = MemoryAgent()
|
||||
@@ -61,8 +65,51 @@ class ConnectionManager:
|
||||
"""断开连接"""
|
||||
if conversation_id in self.active_connections:
|
||||
del self.active_connections[conversation_id]
|
||||
stale_keys = [
|
||||
key
|
||||
for key, state in self.segment_states.items()
|
||||
if key[0] == conversation_id and not state.active_tasks
|
||||
]
|
||||
for key in stale_keys:
|
||||
self.segment_states.pop(key, None)
|
||||
# 清除 Redis 中的会话记忆(可选,也可以保留用于恢复)
|
||||
# await self.conversation_agent.clear_memory(conversation_id)
|
||||
|
||||
def get_or_create_segment_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
voice_session_id: str,
|
||||
) -> "SegmentStreamState":
|
||||
state_key = (conversation_id, voice_session_id)
|
||||
if state_key not in self.segment_states:
|
||||
self.segment_states[state_key] = SegmentStreamState()
|
||||
return self.segment_states[state_key]
|
||||
|
||||
def register_segment_task(
|
||||
self,
|
||||
conversation_id: str,
|
||||
voice_session_id: str,
|
||||
task: asyncio.Task,
|
||||
) -> None:
|
||||
state_key = (conversation_id, voice_session_id)
|
||||
state = self.get_or_create_segment_state(conversation_id, voice_session_id)
|
||||
state.active_tasks.add(task)
|
||||
|
||||
def _cleanup(done_task: asyncio.Task) -> None:
|
||||
state.active_tasks.discard(done_task)
|
||||
if not state.active_tasks and conversation_id not in self.active_connections:
|
||||
self.segment_states.pop(state_key, None)
|
||||
if done_task.cancelled():
|
||||
return
|
||||
exc = done_task.exception()
|
||||
if exc:
|
||||
logger.error(
|
||||
"分段处理任务异常 "
|
||||
f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
task.add_done_callback(_cleanup)
|
||||
|
||||
async def send_message(self, conversation_id: str, message: dict):
|
||||
"""发送消息"""
|
||||
@@ -88,6 +135,286 @@ class ConnectionManager:
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SegmentStreamState:
|
||||
"""会话内分段处理状态(用于并行 ASR + 有序聚合)"""
|
||||
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
pending_indices: Set[int] = field(default_factory=set)
|
||||
processed_indices: Set[int] = field(default_factory=set)
|
||||
buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
|
||||
consumed_index: int = -1
|
||||
active_tasks: Set[asyncio.Task] = field(default_factory=set)
|
||||
|
||||
|
||||
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
|
||||
if voice_session_id:
|
||||
return str(voice_session_id)
|
||||
return LEGACY_VOICE_SESSION_ID
|
||||
|
||||
|
||||
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
|
||||
if not client_segment_id:
|
||||
return None
|
||||
session_id, separator, _ = client_segment_id.rpartition("-")
|
||||
if separator and session_id:
|
||||
return session_id
|
||||
return None
|
||||
|
||||
|
||||
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
|
||||
"""构建分段语音的幂等标识(conversation_id + voice_session_id + segment_index)。"""
|
||||
return f"audio-segment:{voice_session_id}:{segment_index}"
|
||||
|
||||
|
||||
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
|
||||
"""从 audio_url 中解析 voice_session_id 与 segment_index。兼容旧格式 audio-segment:{index}。"""
|
||||
prefix = "audio-segment:"
|
||||
if not audio_url or not audio_url.startswith(prefix):
|
||||
return None
|
||||
payload = audio_url[len(prefix):]
|
||||
voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
|
||||
try:
|
||||
if separator:
|
||||
return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw))
|
||||
return (LEGACY_VOICE_SESSION_ID, int(payload))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
|
||||
if not transcript_text:
|
||||
return True
|
||||
return transcript_text.startswith("转写失败")
|
||||
|
||||
|
||||
async def _find_existing_segment_by_index(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
voice_session_id: str,
|
||||
segment_index: int,
|
||||
) -> Optional[Segment]:
|
||||
"""
|
||||
按 conversation + voice_session_id + segment_index 查找已落库分段。
|
||||
说明:测试桩的 execute() 不会真正执行 where,所以这里做一次 Python 侧过滤,兼容真实 DB 和单测桩。
|
||||
"""
|
||||
segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index)
|
||||
stmt = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id,
|
||||
Segment.audio_url == segment_audio_url,
|
||||
).order_by(Segment.created_at.desc())
|
||||
result = await db.execute(stmt)
|
||||
candidates = result.scalars().all()
|
||||
for item in candidates:
|
||||
if item.conversation_id == conversation_id and item.audio_url == segment_audio_url:
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
async def _get_persisted_contiguous_segment_index(
|
||||
db: AsyncSession,
|
||||
conversation_id: str,
|
||||
voice_session_id: str,
|
||||
) -> int:
|
||||
"""读取数据库中当前 voice session 已连续落库的最大 segment_index,用于重连恢复。"""
|
||||
stmt = select(Segment).where(Segment.conversation_id == conversation_id)
|
||||
result = await db.execute(stmt)
|
||||
candidates = result.scalars().all()
|
||||
|
||||
persisted_indices: Set[int] = set()
|
||||
for item in candidates:
|
||||
if item.conversation_id != conversation_id:
|
||||
continue
|
||||
segment_scope = _extract_segment_scope(item.audio_url)
|
||||
if not segment_scope:
|
||||
continue
|
||||
item_voice_session_id, item_index = segment_scope
|
||||
if item_voice_session_id != voice_session_id:
|
||||
continue
|
||||
persisted_indices.add(item_index)
|
||||
|
||||
contiguous_index = -1
|
||||
while contiguous_index + 1 in persisted_indices:
|
||||
contiguous_index += 1
|
||||
return contiguous_index
|
||||
|
||||
|
||||
async def _send_segment_transition_feedback(
|
||||
conversation_id: str,
|
||||
segment_index: int,
|
||||
manager: ConnectionManager,
|
||||
) -> None:
|
||||
"""ASR 处理中先给陪伴式过渡反馈,避免用户感知卡住。"""
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": "我在认真听,你继续说,我会边听边整理重点。",
|
||||
"transition": True,
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
|
||||
async def _process_audio_segment_async(
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
voice_session_id: str,
|
||||
segment_index: int,
|
||||
audio_base64: str,
|
||||
audio_duration: int,
|
||||
is_last: bool,
|
||||
manager: ConnectionManager,
|
||||
) -> None:
|
||||
"""分段语音的异步处理:并行 ASR + 幂等落库 + 有序聚合触发 Agent。"""
|
||||
state = manager.get_or_create_segment_state(conversation_id, voice_session_id)
|
||||
|
||||
try:
|
||||
# 每个分段任务使用独立 DB Session,避免与主循环共享同一 AsyncSession 导致并发冲突。
|
||||
async for db in get_async_db():
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
user = await db.get(UserModel, user_id)
|
||||
if not conversation:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "对话不存在,分段处理已取消"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
return
|
||||
if not user:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "用户不存在,分段处理已取消"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
return
|
||||
|
||||
async with state.lock:
|
||||
should_prime_state = (
|
||||
state.consumed_index < 0
|
||||
and not state.processed_indices
|
||||
and not state.buffered_transcripts
|
||||
)
|
||||
|
||||
if should_prime_state:
|
||||
persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
voice_session_id=voice_session_id,
|
||||
)
|
||||
if persisted_contiguous_index >= 0:
|
||||
async with state.lock:
|
||||
state.consumed_index = max(state.consumed_index, persisted_contiguous_index)
|
||||
|
||||
transcript_text = await asr_service.transcribe(audio_base64)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": transcript_text or "",
|
||||
"audio_duration": audio_duration,
|
||||
"segment_index": segment_index,
|
||||
"is_last": is_last,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
if _is_transcribe_failure(transcript_text):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": f"分段 {segment_index} 转写失败,请重试该片段",
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
return
|
||||
|
||||
existing_segment = await _find_existing_segment_by_index(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
voice_session_id=voice_session_id,
|
||||
segment_index=segment_index,
|
||||
)
|
||||
if existing_segment:
|
||||
# 该分段已成功入库,视为重传:不重复入库、不重复触发 Agent。
|
||||
async with state.lock:
|
||||
state.processed_indices.add(segment_index)
|
||||
logger.info(
|
||||
"分段已存在,按幂等处理跳过: "
|
||||
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
|
||||
)
|
||||
return
|
||||
else:
|
||||
segment = Segment(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
transcript_text=transcript_text or "",
|
||||
audio_url=_build_segment_audio_url(voice_session_id, segment_index),
|
||||
processed=False,
|
||||
)
|
||||
db.add(segment)
|
||||
await db.commit()
|
||||
await db.refresh(segment)
|
||||
await manager.background_runner.queue_message(conversation.user_id, segment.id)
|
||||
|
||||
ready_segments: List[Tuple[int, str, Segment]] = []
|
||||
async with state.lock:
|
||||
state.processed_indices.add(segment_index)
|
||||
state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
|
||||
|
||||
next_index = state.consumed_index + 1
|
||||
while next_index in state.buffered_transcripts:
|
||||
text, seg = state.buffered_transcripts.pop(next_index)
|
||||
ready_segments.append((next_index, text, seg))
|
||||
state.consumed_index = next_index
|
||||
next_index += 1
|
||||
|
||||
# 仅当前缀分段连续时才触发 Agent,保证增量上下文顺序正确。
|
||||
for _, ordered_text, ordered_segment in ready_segments:
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=ordered_text,
|
||||
conversation=conversation,
|
||||
segment=ordered_segment,
|
||||
db=db,
|
||||
manager=manager,
|
||||
user=user,
|
||||
)
|
||||
|
||||
if is_last:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": "最后一段语音已收到,我会继续完善这一轮总结。",
|
||||
"transition": True,
|
||||
"is_last": True,
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
|
||||
exc_info=True,
|
||||
)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": f"分段处理失败: {str(e)}",
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
finally:
|
||||
async with state.lock:
|
||||
state.pending_indices.discard(segment_index)
|
||||
|
||||
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str
|
||||
@@ -229,8 +556,109 @@ async def websocket_endpoint(
|
||||
segment=segment,
|
||||
db=db,
|
||||
manager=manager,
|
||||
user=user,
|
||||
user=user,
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.AUDIO_SEGMENT:
|
||||
# 处理分段语音消息(长语音持续上传)
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
segment_index_raw = data.get("segment_index")
|
||||
voice_session_id = _normalize_voice_session_id(
|
||||
data.get("voice_session_id")
|
||||
or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
|
||||
)
|
||||
is_last = bool(data.get("is_last", False))
|
||||
audio_duration = int(data.get("duration", 0) or 0)
|
||||
|
||||
if not audio_base64:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
if segment_index_raw is None:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 segment_index"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
try:
|
||||
segment_index = int(segment_index_raw)
|
||||
except (TypeError, ValueError):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 必须为整数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
if segment_index < 0:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 不能为负数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
# 校验对话轮数配额(分段也计入对话轮次)
|
||||
from routers.quota import get_segment_count, check_can_send_message
|
||||
seg_count = await get_segment_count(user_id, db)
|
||||
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
segment_state = manager.get_or_create_segment_state(
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
)
|
||||
should_process = False
|
||||
async with segment_state.lock:
|
||||
already_seen = (
|
||||
segment_index in segment_state.pending_indices
|
||||
or segment_index in segment_state.processed_indices
|
||||
or segment_index <= segment_state.consumed_index
|
||||
)
|
||||
if not already_seen:
|
||||
segment_state.pending_indices.add(segment_index)
|
||||
should_process = True
|
||||
|
||||
if not should_process:
|
||||
logger.info(
|
||||
"收到重复分段,跳过处理: "
|
||||
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 先发过渡反馈,减少“等待空白”体感
|
||||
await _send_segment_transition_feedback(
|
||||
conversation_id=conversation_id,
|
||||
segment_index=segment_index,
|
||||
manager=manager,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
_process_audio_segment_async(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
voice_session_id=voice_session_id,
|
||||
segment_index=segment_index,
|
||||
audio_base64=audio_base64,
|
||||
audio_duration=audio_duration,
|
||||
is_last=is_last,
|
||||
manager=manager,
|
||||
)
|
||||
)
|
||||
manager.register_segment_task(conversation_id, voice_session_id, task)
|
||||
|
||||
elif msg_type == MessageType.AUDIO_MESSAGE:
|
||||
# 处理完整音频消息(类似微信语音)
|
||||
@@ -615,4 +1043,3 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
|
||||
# 同时 flush 任何待处理的任务
|
||||
await manager.background_runner.flush_pending(conversation.user_id)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
@@ -89,6 +90,7 @@ class _FakeAsyncDB:
|
||||
class _FakeManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
self.segment_states = {}
|
||||
self.sent_messages = []
|
||||
self.disconnect_calls = []
|
||||
self.background_runner = SimpleNamespace(
|
||||
@@ -109,6 +111,21 @@ class _FakeManager:
|
||||
async def send_message(self, conversation_id, message):
|
||||
self.sent_messages.append({"conversation_id": conversation_id, "message": message})
|
||||
|
||||
def get_or_create_segment_state(self, conversation_id, voice_session_id):
|
||||
state_key = (conversation_id, voice_session_id)
|
||||
if state_key not in self.segment_states:
|
||||
self.segment_states[state_key] = ws_router.SegmentStreamState()
|
||||
return self.segment_states[state_key]
|
||||
|
||||
def register_segment_task(self, conversation_id, voice_session_id, task):
|
||||
state = self.get_or_create_segment_state(conversation_id, voice_session_id)
|
||||
state.active_tasks.add(task)
|
||||
|
||||
def _cleanup(done_task):
|
||||
state.active_tasks.discard(done_task)
|
||||
|
||||
task.add_done_callback(_cleanup)
|
||||
|
||||
|
||||
def _make_user():
|
||||
# Provide all profile fields to skip greeting/profile-collection branch.
|
||||
@@ -435,6 +452,431 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(len(error_msgs), 1)
|
||||
self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入")
|
||||
|
||||
async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-1",
|
||||
"segment_index": 1,
|
||||
"duration": 12,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-0",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(
|
||||
side_effect=lambda audio: {
|
||||
"seg-0": "这是第 0 段",
|
||||
"seg-1": "这是第 1 段",
|
||||
}[audio]
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
self.assertEqual(transcribe_mock.await_count, 2)
|
||||
ordered_messages = [
|
||||
call.kwargs["user_message"] for call in process_user_message_mock.await_args_list
|
||||
]
|
||||
self.assertEqual(ordered_messages, ["这是第 0 段", "这是第 1 段"])
|
||||
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2)
|
||||
|
||||
async def test_audio_segment_duplicate_index_is_idempotent(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "dup-seg-0-a",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "dup-seg-0-b",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="重复分段去重测试")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
self.assertEqual(transcribe_mock.await_count, 1)
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 1)
|
||||
|
||||
async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "session-a-seg-0",
|
||||
"voice_session_id": "voice-session-a",
|
||||
"client_segment_id": "voice-session-a-0",
|
||||
"segment_index": 0,
|
||||
"duration": 10,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "session-b-seg-0",
|
||||
"voice_session_id": "voice-session-b",
|
||||
"client_segment_id": "voice-session-b-0",
|
||||
"segment_index": 0,
|
||||
"duration": 8,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(
|
||||
side_effect=lambda audio: {
|
||||
"session-a-seg-0": "第一轮第 0 段",
|
||||
"session-b-seg-0": "第二轮第 0 段",
|
||||
}[audio]
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
ordered_messages = [
|
||||
call.kwargs["user_message"] for call in process_user_message_mock.await_args_list
|
||||
]
|
||||
self.assertEqual(ordered_messages, ["第一轮第 0 段", "第二轮第 0 段"])
|
||||
self.assertEqual(transcribe_mock.await_count, 2)
|
||||
self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2)
|
||||
|
||||
async def test_audio_segment_sends_transition_feedback_while_processing(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
fake_db = _FakeAsyncDB(user=user, conversation=conversation)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "slow-seg-0",
|
||||
"segment_index": 0,
|
||||
"duration": 20,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
async def _slow_transcribe(_: str) -> str:
|
||||
await asyncio.sleep(0.2)
|
||||
return "慢速转写"
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(side_effect=_slow_transcribe)
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
transition_msgs = [
|
||||
item["message"]
|
||||
for item in fake_manager.sent_messages
|
||||
if item["message"]["type"] == ws_router.MessageType.AGENT_RESPONSE
|
||||
and item["message"].get("data", {}).get("transition") is True
|
||||
]
|
||||
self.assertGreaterEqual(len(transition_msgs), 1)
|
||||
|
||||
async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
existing_segment = Segment(
|
||||
id="seg-existing-0",
|
||||
conversation_id="conv-1",
|
||||
transcript_text="已存在的上一段",
|
||||
audio_url="audio-segment:voice-session-1:0",
|
||||
processed=False,
|
||||
)
|
||||
fake_db = _FakeAsyncDB(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
segments=[existing_segment],
|
||||
)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-1-after-reconnect",
|
||||
"voice_session_id": "voice-session-1",
|
||||
"client_segment_id": "voice-session-1-1",
|
||||
"segment_index": 1,
|
||||
"duration": 18,
|
||||
"is_last": True,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="这是重连后的第 1 段")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
self.assertEqual(
|
||||
process_user_message_mock.await_args.kwargs["user_message"],
|
||||
"这是重连后的第 1 段",
|
||||
)
|
||||
|
||||
async def test_audio_segment_reconnect_uses_contiguous_prefix_not_max_index(self):
|
||||
user = _make_user()
|
||||
conversation = Conversation(id="conv-1", user_id=user.id, status="active")
|
||||
existing_segments = [
|
||||
Segment(
|
||||
id="seg-existing-0",
|
||||
conversation_id="conv-1",
|
||||
transcript_text="已存在的第 0 段",
|
||||
audio_url="audio-segment:voice-session-gap:0",
|
||||
processed=False,
|
||||
),
|
||||
Segment(
|
||||
id="seg-existing-2",
|
||||
conversation_id="conv-1",
|
||||
transcript_text="已存在的第 2 段",
|
||||
audio_url="audio-segment:voice-session-gap:2",
|
||||
processed=False,
|
||||
),
|
||||
]
|
||||
fake_db = _FakeAsyncDB(
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
segments=existing_segments,
|
||||
)
|
||||
fake_manager = _FakeManager()
|
||||
fake_websocket = _FakeWebSocket(
|
||||
messages=[
|
||||
{
|
||||
"type": "audio_segment",
|
||||
"data": {
|
||||
"audio_base64": "seg-1-gap-retry",
|
||||
"voice_session_id": "voice-session-gap",
|
||||
"client_segment_id": "voice-session-gap-1",
|
||||
"segment_index": 1,
|
||||
"duration": 18,
|
||||
"is_last": False,
|
||||
},
|
||||
},
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
)
|
||||
|
||||
process_user_message_mock = AsyncMock()
|
||||
transcribe_mock = AsyncMock(return_value="补传后的第 1 段")
|
||||
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
ws_router,
|
||||
"verify_token",
|
||||
return_value={"type": "access", "sub": user.id},
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "get_async_db", _db_provider(fake_db))
|
||||
)
|
||||
stack.enter_context(patch.object(ws_router, "manager", fake_manager))
|
||||
stack.enter_context(
|
||||
patch("routers.quota.get_segment_count", new=AsyncMock(return_value=0))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("routers.quota.check_can_send_message", return_value=(True, ""))
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router, "process_user_message", process_user_message_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch.object(ws_router.asr_service, "transcribe", transcribe_mock)
|
||||
)
|
||||
|
||||
await ws_router.websocket_endpoint(fake_websocket, "conv-1")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
process_user_message_mock.assert_awaited_once()
|
||||
self.assertEqual(
|
||||
process_user_message_mock.await_args.kwargs["user_message"],
|
||||
"补传后的第 1 段",
|
||||
)
|
||||
self.assertEqual(transcribe_mock.await_count, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user