feat: 支持长语音分段上传与断线补传
This commit is contained in:
@@ -2,13 +2,15 @@
|
||||
WebSocket 路由:实时对话通信
|
||||
支持异步 Agent 调用和 Redis 会话存储
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status
|
||||
from starlette.websockets import WebSocketState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -21,15 +23,16 @@ from database.models import User as UserModel
|
||||
from services.auth_service import verify_token
|
||||
from services.memoir_state_service import get_or_create_state
|
||||
from services import asr_service
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Fallback voice_session_id: returned when the client supplies no voice session
# id, and used when parsing the old "audio-segment:{index}" identifier format.
LEGACY_VOICE_SESSION_ID = "legacy"
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""WebSocket 消息类型"""
|
||||
CONNECT = "connect"
|
||||
AUDIO_CHUNK = "audio_chunk"
|
||||
AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传)
|
||||
AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音)
|
||||
TRANSCRIBE_ONLY = "transcribe_only" # 仅转写,不落库、不触发 Agent,用于「转文字」发送
|
||||
TEXT = "text" # 文本消息
|
||||
@@ -47,6 +50,7 @@ class ConnectionManager:
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.segment_states: Dict[Tuple[str, str], "SegmentStreamState"] = {}
|
||||
# ConversationAgent 现在是无状态的(会话存储在 Redis),可以复用
|
||||
self.conversation_agent = ConversationAgent()
|
||||
self.memory_agent = MemoryAgent()
|
||||
@@ -61,8 +65,51 @@ class ConnectionManager:
|
||||
"""断开连接"""
|
||||
if conversation_id in self.active_connections:
|
||||
del self.active_connections[conversation_id]
|
||||
stale_keys = [
|
||||
key
|
||||
for key, state in self.segment_states.items()
|
||||
if key[0] == conversation_id and not state.active_tasks
|
||||
]
|
||||
for key in stale_keys:
|
||||
self.segment_states.pop(key, None)
|
||||
# 清除 Redis 中的会话记忆(可选,也可以保留用于恢复)
|
||||
# await self.conversation_agent.clear_memory(conversation_id)
|
||||
|
||||
def get_or_create_segment_state(
    self,
    conversation_id: str,
    voice_session_id: str,
) -> "SegmentStreamState":
    """Return the segment state for this conversation/voice session pair.

    A new ``SegmentStreamState`` is created and stored on first access; later
    calls with the same pair return the same object.
    """
    lookup_key = (conversation_id, voice_session_id)
    try:
        return self.segment_states[lookup_key]
    except KeyError:
        created = SegmentStreamState()
        self.segment_states[lookup_key] = created
        return created
|
||||
|
||||
def register_segment_task(
    self,
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track a segment-processing task and tidy up when it finishes.

    The task is added to the ``active_tasks`` set of the matching segment
    state. When the task completes, a done-callback removes it again, drops
    the whole state entry if no task is left and the conversation is no
    longer connected, and logs any exception the task raised.

    Args:
        conversation_id: Conversation the task belongs to.
        voice_session_id: Voice session the task belongs to.
        task: The already-created asyncio task to track.
    """
    state_key = (conversation_id, voice_session_id)
    state = self.get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        # Only forget the state once nothing is running AND the socket is gone;
        # while the connection is alive the state is still needed for ordering.
        if not state.active_tasks and conversation_id not in self.active_connections:
            self.segment_states.pop(state_key, None)
        # Task.exception() raises CancelledError on a cancelled task, so bail out first.
        if done_task.cancelled():
            return
        exc = done_task.exception()
        if exc is not None:
            # A done-callback runs outside any ``except`` block, so
            # ``exc_info=True`` would log no traceback; pass the exception itself.
            logger.error(
                "分段处理任务异常 "
                f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
                exc_info=exc,
            )

    task.add_done_callback(_cleanup)
|
||||
|
||||
async def send_message(self, conversation_id: str, message: dict):
|
||||
"""发送消息"""
|
||||
@@ -88,6 +135,286 @@ class ConnectionManager:
|
||||
# Module-level ConnectionManager instance used by the handlers in this module.
manager = ConnectionManager()
|
||||
|
||||
|
||||
@dataclass
class SegmentStreamState:
    """Per-voice-session segment processing state (parallel ASR + ordered aggregation)."""

    # Held by the visible callers while reading or mutating the fields below.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Indices accepted by the WebSocket handler whose processing task has not finished yet.
    pending_indices: Set[int] = field(default_factory=set)
    # Indices that were persisted by a task, or found already persisted (retransmission).
    processed_indices: Set[int] = field(default_factory=set)
    # Transcripts waiting for their predecessors: segment_index -> (text, Segment row).
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest contiguous index already consumed (or restored from the database); -1 means none.
    consumed_index: int = -1
    # Tasks registered through ConnectionManager.register_segment_task that are still running.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
|
||||
|
||||
|
||||
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
|
||||
if voice_session_id:
|
||||
return str(voice_session_id)
|
||||
return LEGACY_VOICE_SESSION_ID
|
||||
|
||||
|
||||
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
|
||||
if not client_segment_id:
|
||||
return None
|
||||
session_id, separator, _ = client_segment_id.rpartition("-")
|
||||
if separator and session_id:
|
||||
return session_id
|
||||
return None
|
||||
|
||||
|
||||
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
|
||||
"""构建分段语音的幂等标识(conversation_id + voice_session_id + segment_index)。"""
|
||||
return f"audio-segment:{voice_session_id}:{segment_index}"
|
||||
|
||||
|
||||
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
|
||||
"""从 audio_url 中解析 voice_session_id 与 segment_index。兼容旧格式 audio-segment:{index}。"""
|
||||
prefix = "audio-segment:"
|
||||
if not audio_url or not audio_url.startswith(prefix):
|
||||
return None
|
||||
payload = audio_url[len(prefix):]
|
||||
voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
|
||||
try:
|
||||
if separator:
|
||||
return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw))
|
||||
return (LEGACY_VOICE_SESSION_ID, int(payload))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
|
||||
if not transcript_text:
|
||||
return True
|
||||
return transcript_text.startswith("转写失败")
|
||||
|
||||
|
||||
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Look up an already persisted segment by conversation, voice session and index.

    The query filters on ``conversation_id`` and the segment's ``audio_url``
    identifier, newest first. The rows are filtered again in Python because the
    test stub's ``execute()`` does not apply the ``where`` clause; this keeps
    the function correct against both a real database and the stub.

    Returns:
        The first matching ``Segment``, or ``None`` when there is none.
    """
    target_url = _build_segment_audio_url(voice_session_id, segment_index)
    query = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == target_url,
        )
        .order_by(Segment.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id and row.audio_url == target_url
        ),
        None,
    )
|
||||
|
||||
|
||||
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest index N such that segments 0..N of this voice session are persisted.

    Used to restore progress after a reconnect. All segments of the
    conversation are loaded and their ``audio_url`` identifiers parsed; rows of
    other voice sessions, or with unparseable identifiers, are ignored.

    Returns:
        The last index of the gap-free run starting at 0, or ``-1`` when
        segment 0 is not persisted.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()

    seen: Set[int] = set()
    for row in rows:
        # Re-check the conversation in Python as well as in the query.
        if row.conversation_id != conversation_id:
            continue
        scope = _extract_segment_scope(row.audio_url)
        if scope and scope[0] == voice_session_id:
            seen.add(scope[1])

    # Walk upwards from 0 until the first missing index.
    last_contiguous = -1
    while (last_contiguous + 1) in seen:
        last_contiguous += 1
    return last_contiguous
|
||||
|
||||
|
||||
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
    manager: ConnectionManager,
) -> None:
    """Send a short "still listening" agent message for a segment.

    It is sent before the segment is transcribed so the user does not perceive
    the wait as a stall. The message is flagged with ``transition: True``.
    """
    feedback = {
        "text": "我在认真听,你继续说,我会边听边整理重点。",
        "transition": True,
        "segment_index": segment_index,
    }
    envelope = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": feedback,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, envelope)
|
||||
|
||||
|
||||
async def _process_audio_segment_async(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
    manager: ConnectionManager,
) -> None:
    """Process one voice segment in the background.

    Steps: load the conversation and user, restore progress from the database
    if the in-memory state is still empty, transcribe the audio, persist the
    segment unless it already exists (retransmission), then hand every segment
    that has become contiguous with the consumed prefix to
    ``process_user_message`` in index order.

    Errors are logged and reported to the client as ERROR messages. The
    segment index is always removed from ``pending_indices`` at the end.
    """

    async def _send_error(data: dict) -> None:
        # Every error notification of this task uses the same envelope.
        await manager.send_message(conversation_id, {
            "type": MessageType.ERROR,
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })

    state = manager.get_or_create_segment_state(conversation_id, voice_session_id)

    try:
        # Each segment task gets its own DB session so it never shares an
        # AsyncSession with the main receive loop.
        async for db in get_async_db():
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(UserModel, user_id)
            if not conversation:
                await _send_error({"message": "对话不存在,分段处理已取消"})
                return
            if not user:
                await _send_error({"message": "用户不存在,分段处理已取消"})
                return

            # The state counts as fresh when nothing was consumed, processed or buffered.
            async with state.lock:
                state_is_fresh = not (
                    state.consumed_index >= 0
                    or state.processed_indices
                    or state.buffered_transcripts
                )

            if state_is_fresh:
                # NOTE(review): this read may see rows committed by sibling tasks
                # that have not yet registered themselves in ``state`` — confirm
                # this cannot move consumed_index past a still-buffered segment.
                restored_index = await _get_persisted_contiguous_segment_index(
                    db=db,
                    conversation_id=conversation_id,
                    voice_session_id=voice_session_id,
                )
                if restored_index >= 0:
                    # The lock was released during the query; never move backwards.
                    async with state.lock:
                        if restored_index > state.consumed_index:
                            state.consumed_index = restored_index

            transcript_text = await asr_service.transcribe(audio_base64)
            transcript_payload = {
                "text": transcript_text or "",
                "audio_duration": audio_duration,
                "segment_index": segment_index,
                "is_last": is_last,
            }
            await manager.send_message(conversation_id, {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": transcript_payload,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })

            if _is_transcribe_failure(transcript_text):
                # Nothing is persisted, so the client can retry this index.
                await _send_error({
                    "message": f"分段 {segment_index} 转写失败,请重试该片段",
                    "segment_index": segment_index,
                })
                return

            already_stored = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if already_stored:
                # Retransmission of a stored segment: no second row, no second agent call.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.info(
                    "分段已存在,按幂等处理跳过: "
                    f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
                )
                return

            segment = Segment(
                id=str(uuid.uuid4()),
                conversation_id=conversation_id,
                transcript_text=transcript_text or "",
                audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                processed=False,
            )
            db.add(segment)
            await db.commit()
            await db.refresh(segment)
            await manager.background_runner.queue_message(conversation.user_id, segment.id)

            ready: List[Tuple[str, Segment]] = []
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
                # Drain every buffered segment that directly follows the consumed prefix.
                while (state.consumed_index + 1) in state.buffered_transcripts:
                    state.consumed_index += 1
                    ready.append(state.buffered_transcripts.pop(state.consumed_index))

            # Only a gap-free prefix reaches the agent, so it sees segments in order.
            for ordered_text, ordered_segment in ready:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    manager=manager,
                    user=user,
                )

            if is_last:
                closing_payload = {
                    "text": "最后一段语音已收到,我会继续完善这一轮总结。",
                    "transition": True,
                    "is_last": True,
                    "segment_index": segment_index,
                }
                await manager.send_message(conversation_id, {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": closing_payload,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
            # Only the first yielded session is used.
            break

    except Exception as e:
        logger.error(
            f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
            exc_info=True,
        )
        await _send_error({
            "message": f"分段处理失败: {str(e)}",
            "segment_index": segment_index,
        })
    finally:
        # Whatever happened, this index is no longer in flight.
        async with state.lock:
            state.pending_indices.discard(segment_index)
|
||||
|
||||
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str
|
||||
@@ -229,8 +556,109 @@ async def websocket_endpoint(
|
||||
segment=segment,
|
||||
db=db,
|
||||
manager=manager,
|
||||
user=user,
|
||||
user=user,
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.AUDIO_SEGMENT:
|
||||
# 处理分段语音消息(长语音持续上传)
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
segment_index_raw = data.get("segment_index")
|
||||
voice_session_id = _normalize_voice_session_id(
|
||||
data.get("voice_session_id")
|
||||
or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
|
||||
)
|
||||
is_last = bool(data.get("is_last", False))
|
||||
audio_duration = int(data.get("duration", 0) or 0)
|
||||
|
||||
if not audio_base64:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
if segment_index_raw is None:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 segment_index"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
try:
|
||||
segment_index = int(segment_index_raw)
|
||||
except (TypeError, ValueError):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 必须为整数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
if segment_index < 0:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 不能为负数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
# 校验对话轮数配额(分段也计入对话轮次)
|
||||
from routers.quota import get_segment_count, check_can_send_message
|
||||
seg_count = await get_segment_count(user_id, db)
|
||||
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
continue
|
||||
|
||||
segment_state = manager.get_or_create_segment_state(
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
)
|
||||
should_process = False
|
||||
async with segment_state.lock:
|
||||
already_seen = (
|
||||
segment_index in segment_state.pending_indices
|
||||
or segment_index in segment_state.processed_indices
|
||||
or segment_index <= segment_state.consumed_index
|
||||
)
|
||||
if not already_seen:
|
||||
segment_state.pending_indices.add(segment_index)
|
||||
should_process = True
|
||||
|
||||
if not should_process:
|
||||
logger.info(
|
||||
"收到重复分段,跳过处理: "
|
||||
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 先发过渡反馈,减少“等待空白”体感
|
||||
await _send_segment_transition_feedback(
|
||||
conversation_id=conversation_id,
|
||||
segment_index=segment_index,
|
||||
manager=manager,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(
|
||||
_process_audio_segment_async(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
voice_session_id=voice_session_id,
|
||||
segment_index=segment_index,
|
||||
audio_base64=audio_base64,
|
||||
audio_duration=audio_duration,
|
||||
is_last=is_last,
|
||||
manager=manager,
|
||||
)
|
||||
)
|
||||
manager.register_segment_task(conversation_id, voice_session_id, task)
|
||||
|
||||
elif msg_type == MessageType.AUDIO_MESSAGE:
|
||||
# 处理完整音频消息(类似微信语音)
|
||||
@@ -615,4 +1043,3 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
|
||||
# 同时 flush 任何待处理的任务
|
||||
await manager.background_runner.flush_pending(conversation.user_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user