223 lines
7.8 KiB
Python
223 lines
7.8 KiB
Python
"""语音桌面终端:assignment 状态、WebSocket 推送与 HTTP 拉取兜底。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
from asyncio import Lock
|
||
from collections import defaultdict
|
||
from collections.abc import Awaitable, Callable
|
||
from typing import Any
|
||
|
||
from fastapi import WebSocket
|
||
from loguru import logger
|
||
from starlette.websockets import WebSocketDisconnect
|
||
|
||
from app.config import Settings
|
||
from app.services.voice_terminal_binding import VoiceTerminalBindingIndex
|
||
|
||
# Async callable taking a surgery_id and returning the current pending-queue head
# payload, or None when there is no head (the hub then pushes voice_pending_empty).
# NOTE(review): the hub calls payload.model_dump(mode="json"), so a non-None result is
# presumably a pydantic model — confirm against the fetcher implementation.
PendingHeadFetcher = Callable[[str], Awaitable[Any]]
|
||
|
||
|
||
async def assign_voice_terminal_after_recording_started(
    hub: VoiceTerminalHub,
    *,
    surgery_id: str,
    camera_ids: list[str],
    set_voice_terminal_id: Callable[[str, str | None], None],
) -> None:
    """After recording has started, route the surgery to its bound voice terminal.

    Resolves the terminal from the site bindings using *camera_ids*. On a match,
    the terminal id is stored on the session via *set_voice_terminal_id* and a
    ``start`` assignment is pushed over WebSocket (same flow as the HTTP
    start-recording path). Without a match, only a warning is logged. The
    warning text differs depending on whether any bindings were loaded.

    Args:
        hub: Hub holding terminal bindings and WebSocket connections.
        surgery_id: Surgery whose recording just started.
        camera_ids: Cameras used for the recording; matched against bindings.
        set_voice_terminal_id: Callback ``(surgery_id, terminal_id)`` that
            persists the chosen terminal on the session.
    """
    terminal = hub.resolve_terminal(list(camera_ids))

    if not terminal:
        # Distinguish "bindings exist but nothing matched" from "no bindings at all".
        if hub.bindings is None:
            logger.warning(
                "开录未推送语音终端:未加载 OR_SITE_CONFIG 或 voice_or_room_bindings 为空;"
                "桌面端 WebSocket 不会收到 voice_assignment surgery_id={}",
                surgery_id,
            )
        else:
            logger.warning(
                "开录未向任何语音终端推送:camera_ids 与 OR_SITE_CONFIG「voice_or_room_bindings」无匹配 "
                "surgery_id={} camera_ids={}",
                surgery_id,
                camera_ids,
            )
        return

    set_voice_terminal_id(surgery_id, terminal)
    await hub.notify_start(terminal, surgery_id)
|
||
|
||
|
||
class VoiceTerminalHub:
    """In-process registry of voice-terminal WebSocket connections and surgery assignments.

    All state lives in this process only; running multiple workers requires
    separate synchronisation, which this class does not provide.
    """

    def __init__(
        self,
        settings: Settings,
        *,
        pending_head_fetcher: PendingHeadFetcher | None = None,
    ) -> None:
        """Load site bindings from *settings* and initialise empty hub state.

        Args:
            settings: Application settings; ``load_or_site_config()`` supplies the
                optional voice-terminal binding index (``voice_bindings``).
            pending_head_fetcher: Optional async callable ``surgery_id -> payload``
                used to push the pending-queue head. When ``None``, pending-head
                pushes are skipped.
        """
        cfg = settings.load_or_site_config()
        self._bindings = cfg.voice_bindings if cfg else None
        # terminal_id -> surgery_id currently assigned to that terminal.
        self._assignments: dict[str, str] = {}
        # Serialises mutation of _assignments and _connections. asyncio.Lock is not
        # reentrant, so _broadcast (which takes it) must never be awaited while held.
        self._lock = Lock()
        # terminal_id -> open WebSocket connections for that terminal.
        self._connections: dict[str, set[WebSocket]] = defaultdict(set)
        self._pending_head_fetcher = pending_head_fetcher
        # Strong references to fire-and-forget tasks: the event loop keeps only weak
        # references, so an unreferenced task may be garbage-collected mid-execution.
        self._background_tasks: set[asyncio.Task[None]] = set()

    @property
    def bindings(self) -> VoiceTerminalBindingIndex | None:
        """Binding index loaded from the site config, or ``None`` if none was loaded."""
        return self._bindings

    def resolve_terminal(self, camera_ids: list[str]) -> str | None:
        """Return the terminal bound to *camera_ids*, or ``None`` when there are no bindings."""
        if self._bindings is None:
            return None
        return self._bindings.resolve_terminal(camera_ids)

    def get_assignment(self, terminal_id: str) -> str | None:
        """Return the surgery currently assigned to *terminal_id* (whitespace-stripped)."""
        return self._assignments.get(terminal_id.strip())

    async def notify_start(self, terminal_id: str, surgery_id: str) -> None:
        """Assign *surgery_id* to the terminal and push a ``start`` message.

        Blank terminal ids are ignored. After the push, the pending-queue head is
        scheduled for delivery in the background.
        """
        tid = terminal_id.strip()
        if not tid:
            return
        payload = {
            "type": "voice_assignment",
            "action": "start",
            "surgery_id": surgery_id,
        }
        async with self._lock:
            self._assignments[tid] = surgery_id
        # Broadcast outside the lock: _broadcast acquires it itself.
        await self._broadcast(tid, payload)
        logger.info(
            "Voice terminal {} assigned surgery {} (start push)",
            tid,
            surgery_id,
        )
        self.schedule_notify_pending_head(tid, surgery_id)

    def schedule_notify_pending_head(self, terminal_id: str, surgery_id: str) -> None:
        """Push the pending-queue head (incl. TTS) in the background without blocking.

        Must be called while an event loop is running. Blank *terminal_id* or
        *surgery_id* is ignored.
        """
        tid = terminal_id.strip()
        sid = (surgery_id or "").strip()
        if not tid or not sid:
            return
        task = asyncio.create_task(self._notify_pending_head_safe(tid, sid))
        # Keep the task alive until it finishes, then drop our reference.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def notify_end(self, terminal_id: str | None, surgery_id: str) -> None:
        """Release the terminal's assignment and push an ``end`` message.

        The assignment is cleared only if it still points at *surgery_id*, so a
        newer assignment is not clobbered. The ``end`` message is broadcast
        regardless. ``None`` or blank terminal ids are ignored.
        """
        if not terminal_id:
            return
        tid = terminal_id.strip()
        if not tid:
            return
        payload = {
            "type": "voice_assignment",
            "action": "end",
            "surgery_id": surgery_id,
        }
        async with self._lock:
            if self._assignments.get(tid) == surgery_id:
                del self._assignments[tid]
        # Broadcast outside the lock: _broadcast acquires it itself.
        await self._broadcast(tid, payload)
        logger.info(
            "Voice terminal {} released surgery {} (end push)",
            tid,
            surgery_id,
        )

    async def notify_pending_head(self, terminal_id: str, surgery_id: str) -> None:
        """Push the current FIFO head (incl. TTS) to the terminal.

        Pushes ``voice_pending_empty`` when the fetcher returns ``None``. Does
        nothing if no fetcher is configured or either id is blank. Fetch or
        serialisation failures are logged and swallowed (best effort).
        """
        fetcher = self._pending_head_fetcher
        tid = terminal_id.strip()
        sid = (surgery_id or "").strip()
        if not fetcher or not tid or not sid:
            return
        try:
            payload = await fetcher(sid)
        except Exception as exc:
            logger.warning(
                "voice_pending 构建失败 surgery_id={} terminal_id={}: {}",
                sid,
                tid,
                exc,
            )
            return
        if payload is None:
            await self._broadcast(
                tid,
                {"type": "voice_pending_empty", "surgery_id": sid},
            )
            return
        try:
            data = payload.model_dump(mode="json")
        except Exception as exc:
            logger.warning("voice_pending 序列化失败 surgery_id={}: {}", sid, exc)
            return
        data["type"] = "voice_pending"
        await self._broadcast(tid, data)

    async def _notify_pending_head_safe(
        self, terminal_id: str, surgery_id: str
    ) -> None:
        """Background-task wrapper: never lets an exception escape the task."""
        try:
            await self.notify_pending_head(terminal_id, surgery_id)
        except Exception as exc:
            logger.warning(
                "后台 voice_pending 推送失败 terminal_id={} surgery_id={}: {}",
                terminal_id,
                surgery_id,
                exc,
            )

    async def handle_websocket(self, websocket: WebSocket, terminal_id: str) -> None:
        """Serve one terminal's WebSocket until it disconnects.

        Rejects blank terminal ids with close code 4400. On connect, re-sends the
        current assignment (if any) so a terminal that connected late does not
        miss ``start``. Always unregisters the socket on exit.
        """
        tid = terminal_id.strip()
        if not tid:
            await websocket.close(code=4400)
            return
        await websocket.accept()
        async with self._lock:
            self._connections[tid].add(websocket)
        try:
            # Push the current assignment immediately so a missed start is recovered.
            sid = self._assignments.get(tid)
            if sid:
                await websocket.send_text(
                    json.dumps(
                        {
                            "type": "voice_assignment",
                            "action": "start",
                            "surgery_id": sid,
                        },
                        ensure_ascii=False,
                    )
                )
                self.schedule_notify_pending_head(tid, sid)
            # Do not use receive_text(): the desktop websocket-client sends ping /
            # binary control frames, which ASGI may surface as websocket.receive
            # without a "text" key; receive_text would then raise KeyError and drop
            # the connection.
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    break
        except WebSocketDisconnect:
            pass
        finally:
            async with self._lock:
                conns = self._connections.get(tid)
                if conns:
                    conns.discard(websocket)
                    if not conns:
                        del self._connections[tid]

    async def _broadcast(self, terminal_id: str, payload: dict) -> None:
        """Send *payload* as JSON to every connection of *terminal_id*.

        Sockets whose send fails are unregistered. Sending happens outside the
        lock on a snapshot of the connection set.
        """
        text = json.dumps(payload, ensure_ascii=False)
        async with self._lock:
            targets = list(self._connections.get(terminal_id, ()))
        dead: list[WebSocket] = []
        for ws in targets:
            try:
                await ws.send_text(text)
            except Exception as exc:
                logger.debug("voice terminal ws send failed: {}", exc)
                dead.append(ws)
        if dead:
            async with self._lock:
                # Use .get: indexing the defaultdict would resurrect an entry that
                # handle_websocket already removed.
                conns = self._connections.get(terminal_id)
                if conns is not None:
                    conns.difference_update(dead)
                    if not conns:
                        del self._connections[terminal_id]
|