ver0.1
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import array
|
||||
import io
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -9,12 +10,57 @@ import wave
|
||||
from typing import Final
|
||||
|
||||
_BAIDU_RATE: Final[int] = 16000
|
||||
# 诊室麦克风常见音量偏小,百度短语音 3301「语音质量错误」多与有效幅度过低有关。
|
||||
_NORM_TARGET_PEAK: Final[int] = 12000
|
||||
_NORM_MAX_GAIN: Final[float] = 80.0
|
||||
|
||||
|
||||
class WavDecodeError(ValueError):
    """Signal that audio bytes could not be decoded or converted.

    Raised when uploaded bytes are not a valid WAV, or when they cannot be
    converted to the PCM layout this module produces.  It subclasses
    ``ValueError`` so generic value-error handlers keep working.
    """
|
||||
|
||||
|
||||
def pcm_s16le_to_wav_bytes(pcm: bytes, *, sample_rate: int = _BAIDU_RATE) -> bytes:
    """Wrap raw signed 16-bit little-endian mono PCM in a standard WAV container.

    Used, for example, when retrying a Baidu request with ``format=wav``.

    Args:
        pcm: Raw mono s16le sample bytes.
        sample_rate: Frame rate written into the WAV header.

    Returns:
        The complete WAV file content as bytes.

    Raises:
        WavDecodeError: If ``pcm`` is empty.
    """
    if not pcm:
        raise WavDecodeError("Empty PCM")

    sink = io.BytesIO()
    writer = wave.open(sink, "wb")
    try:
        # One channel, two bytes per sample: mono 16-bit audio.
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(sample_rate)
        writer.writeframes(pcm)
    finally:
        # Closing finalizes the header; the in-memory buffer stays readable.
        writer.close()
    return sink.getvalue()
|
||||
|
||||
|
||||
def normalize_pcm_s16le_for_baidu(
    pcm: bytes,
    *,
    target_peak: int | None = None,
    max_gain: float | None = None,
) -> bytes:
    """Amplify a too-quiet s16le PCM signal; leave loud-enough audio untouched.

    The intent is to lower the chance of Baidu ``err_no=3301`` (speech
    quality) responses caused by very low signal amplitude.

    Args:
        pcm: Raw signed 16-bit little-endian sample bytes.
        target_peak: Absolute peak the signal is scaled towards.  Defaults to
            the module constant ``_NORM_TARGET_PEAK``.
        max_gain: Upper bound on the applied gain.  Defaults to the module
            constant ``_NORM_MAX_GAIN``.

    Returns:
        The scaled PCM bytes, or ``pcm`` itself when it is shorter than one
        sample, has an odd length, is pure silence, already reaches
        ``target_peak``, or the resulting gain would not exceed 1.0.
    """
    import sys

    if len(pcm) < 2 or len(pcm) % 2 != 0:
        return pcm

    if target_peak is None:
        target_peak = _NORM_TARGET_PEAK
    if max_gain is None:
        max_gain = _NORM_MAX_GAIN

    samples = array.array("h")
    samples.frombytes(pcm)
    if not samples:
        return pcm

    # ``array`` stores values in native byte order, but the input is declared
    # little-endian: swap on big-endian hosts so the amplitudes read correctly.
    needs_swap = sys.byteorder == "big"
    if needs_swap:
        samples.byteswap()

    # Largest absolute amplitude; -min() covers -32768, whose magnitude
    # exceeds the positive int16 range.
    peak = max(max(samples), -min(samples))
    if peak == 0 or peak >= target_peak:
        return pcm

    scale = min(float(max_gain), float(target_peak) / float(peak))
    if scale <= 1.0:
        return pcm

    # Clamp to the int16 range so an overly large custom target cannot overflow.
    out = array.array(
        "h",
        (max(-32768, min(32767, int(round(float(s) * scale)))) for s in samples),
    )
    if needs_swap:
        out.byteswap()
    return out.tobytes()
|
||||
|
||||
|
||||
def wav_bytes_to_pcm16k_mono_s16le(wav_bytes: bytes) -> bytes:
|
||||
"""
|
||||
Prefer ffmpeg for arbitrary channel count / sample rate.
|
||||
@@ -57,7 +103,7 @@ def _ffmpeg_to_pcm16k(wav_bytes: bytes, ffmpeg: str) -> bytes:
|
||||
raise WavDecodeError(f"ffmpeg wav decode failed: {err or proc.returncode}")
|
||||
if not proc.stdout:
|
||||
raise WavDecodeError("ffmpeg produced empty PCM")
|
||||
return proc.stdout
|
||||
return normalize_pcm_s16le_for_baidu(proc.stdout)
|
||||
|
||||
|
||||
def _stdlib_wave_to_pcm16k(wav_bytes: bytes) -> bytes:
|
||||
@@ -96,6 +142,6 @@ def _stdlib_wave_to_pcm16k(wav_bytes: bytes) -> bytes:
|
||||
l_s, r_s = struct.unpack("<hh", chunk)
|
||||
m = max(min((l_s + r_s) // 2, 32767), -32768)
|
||||
out.extend(struct.pack("<h", m))
|
||||
return bytes(out)
|
||||
return normalize_pcm_s16le_for_baidu(bytes(out))
|
||||
|
||||
return raw
|
||||
return normalize_pcm_s16le_for_baidu(raw)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from aip import AipSpeech
|
||||
|
||||
from app.config import Settings, settings as _default_settings
|
||||
from app.services.audio_wav import pcm_s16le_to_wav_bytes
|
||||
|
||||
|
||||
class BaiduSpeechNotConfiguredError(RuntimeError):
|
||||
@@ -60,6 +61,31 @@ class BaiduSpeechService:
|
||||
merged["dev_pid"] = int(self._s.baidu_speech_asr_dev_pid)
|
||||
return self._client_or_raise().asr(speech, format, rate, merged)
|
||||
|
||||
def asr_16k_mono_pcm_or_wav_fallback(
    self,
    pcm_s16le: bytes,
    *,
    rate: int = 16000,
    options: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Recognize raw PCM first; on ``err_no=3301`` retry once as WAV.

    Error 3301 is Baidu's "speech quality" error.  The PCM and WAV paths can
    behave differently on borderline samples in some environments, so a single
    retry with the same audio wrapped in a WAV container may succeed.

    Args:
        pcm_s16le: Raw s16le mono PCM bytes.
        rate: Sample rate passed to both attempts.
        options: Extra options forwarded unchanged to ``self.asr``.

    Returns:
        The first response unless it is a dict with ``err_no == 3301``, the
        audio is at least 1000 bytes, the WAV wrapping succeeds, and the retry
        returns a dict; in that case the retry response.
    """
    first = self.asr(pcm_s16le, "pcm", rate, options)

    should_retry = (
        isinstance(first, dict)
        and first.get("err_no") == 3301
        and len(pcm_s16le) >= 1000
    )
    if not should_retry:
        return first

    try:
        wav_payload = pcm_s16le_to_wav_bytes(pcm_s16le, sample_rate=rate)
    except Exception:
        # Best effort: if wrapping fails, report the original recognition error.
        return first

    second = self.asr(wav_payload, "wav", rate, options)
    if isinstance(second, dict):
        return second
    return first
|
||||
|
||||
def synthesis(
|
||||
self,
|
||||
text: str,
|
||||
|
||||
@@ -111,6 +111,10 @@ class SurgeryPipeline:
|
||||
pending = self._sessions.next_pending_confirmation(surgery_id)
|
||||
if pending is None:
|
||||
return None
|
||||
queue_len = self._sessions.pending_queue_pending_count(surgery_id)
|
||||
qpos = self._sessions.pending_queue_position_1based(surgery_id, pending.id)
|
||||
if qpos is None or qpos < 1:
|
||||
qpos = 1
|
||||
mp3 = await run_in_threadpool(
|
||||
self._voice.synthesize_prompt_to_mp3,
|
||||
pending.prompt_text,
|
||||
@@ -119,6 +123,9 @@ class SurgeryPipeline:
|
||||
return SurgeryPendingConfirmationResponse(
|
||||
surgery_id=surgery_id,
|
||||
confirmation_id=pending.id,
|
||||
pending_queue_length=max(1, queue_len),
|
||||
pending_queue_position=qpos,
|
||||
pending_cumulative_ordinal=max(1, pending.enqueue_ordinal),
|
||||
prompt_text=pending.prompt_text,
|
||||
prompt_audio_mp3_base64=b64,
|
||||
options=[
|
||||
|
||||
@@ -14,6 +14,8 @@ from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from app.services.voice_terminal_hub import VoiceTerminalHub
|
||||
|
||||
from app.baked import pipeline as bp
|
||||
from app.services.consumable_vision_algorithm import (
|
||||
PredictionCandidate,
|
||||
@@ -56,8 +58,13 @@ class VisionClassificationHandler:
|
||||
self,
|
||||
*,
|
||||
registry: SurgerySessionRegistry,
|
||||
voice_terminal_hub: VoiceTerminalHub | None = None,
|
||||
) -> None:
|
||||
self._registry = registry
|
||||
self._voice_hub = voice_terminal_hub
|
||||
|
||||
def attach_voice_terminal_hub(self, hub: VoiceTerminalHub | None) -> None:
    """Set (or clear, with ``None``) the voice terminal hub used by this handler."""
    self._voice_hub = hub
|
||||
|
||||
def _append_vision_consumption_window_if_ready(
|
||||
self,
|
||||
@@ -212,3 +219,7 @@ class VisionClassificationHandler:
|
||||
confirmation_id=cid,
|
||||
doctor_id=bp.VIDEO_RESULT_DOCTOR_ID,
|
||||
)
|
||||
hub = self._voice_hub
|
||||
vtid = (state.voice_terminal_id or "").strip()
|
||||
if hub is not None and vtid and surgery_id:
|
||||
hub.schedule_notify_pending_head(vtid, surgery_id)
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.services.consumable_vision_algorithm import (
|
||||
from app.services.video.archive_persister import ArchivePersister
|
||||
from app.services.video.backend_resolver import BackendResolver
|
||||
from app.services.video.classification_handler import VisionClassificationHandler
|
||||
from app.services.voice_terminal_hub import VoiceTerminalHub
|
||||
from app.services.video.hikvision_runtime import HikvisionInitRefCount, HikvisionRuntime
|
||||
from app.services.video.inference_aggregator import WindowInferenceAggregator
|
||||
from app.services.video.session_registry import (
|
||||
@@ -101,6 +102,16 @@ class CameraSessionManager:
|
||||
registry=self._registry,
|
||||
)
|
||||
|
||||
def set_voice_terminal_hub(self, hub: VoiceTerminalHub | None) -> None:
    """Forward ``hub`` to the classification handler owned by this manager."""
    handler = self._classifier_handler
    handler.attach_voice_terminal_hub(hub)
|
||||
|
||||
def get_voice_terminal_id_if_active(self, surgery_id: str) -> str | None:
    """Return the stripped voice terminal id of a running surgery, if any.

    Returns ``None`` when the registry has no running session for
    ``surgery_id`` or when the session's terminal id is missing or blank.
    """
    running = self._registry.get_running(surgery_id)
    if running is None:
        return None

    raw_id = running.state.voice_terminal_id
    cleaned = raw_id.strip() if raw_id else ""
    return cleaned if cleaned else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ------------------------------------------------------------------
|
||||
@@ -310,6 +321,16 @@ class CameraSessionManager:
|
||||
) -> PendingConsumableConfirmation | None:
|
||||
return self._registry.next_pending_confirmation(surgery_id)
|
||||
|
||||
def pending_queue_pending_count(self, surgery_id: str) -> int:
    """Delegate to the session registry's count of pending confirmations."""
    registry = self._registry
    return registry.pending_queue_pending_count(surgery_id)
|
||||
|
||||
def pending_queue_position_1based(
    self, surgery_id: str, confirmation_id: str
) -> int | None:
    """Delegate to the session registry's 1-based queue position lookup."""
    registry = self._registry
    return registry.pending_queue_position_1based(surgery_id, confirmation_id)
|
||||
|
||||
async def resolve_pending_confirmation(
|
||||
self,
|
||||
surgery_id: str,
|
||||
|
||||
@@ -52,6 +52,8 @@ class PendingConsumableConfirmation:
|
||||
model_top1_confidence: float
|
||||
#: 本轮待确认在解析失败时累计次数(首败 + 重试),供 API 计算 retry_remaining。
|
||||
voice_parse_failures: int = 0
|
||||
#: 本场手术中待确认任务入队时的累计序号(自 1 起,入队时递增)。
|
||||
enqueue_ordinal: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,6 +89,8 @@ class SurgerySessionState:
|
||||
surgery_started_wall: float | None = None
|
||||
#: 术间绑定配置解析出的语音桌面终端 ID;停录时用于推送 end。
|
||||
voice_terminal_id: str | None = None
|
||||
#: 待确认入队已分配到的最大序号(与 ``pending_by_id`` 中 ``enqueue_ordinal`` 一致递增)。
|
||||
pending_ordinal_next: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -203,6 +207,32 @@ class SurgerySessionRegistry:
|
||||
return p
|
||||
return None
|
||||
|
||||
def pending_queue_pending_count(self, surgery_id: str) -> int:
    """Count FIFO entries whose confirmation still has status ``"pending"``.

    Ids in ``pending_fifo`` that have no entry in ``pending_by_id`` are
    ignored.  Returns 0 when ``surgery_id`` has no active session.
    """
    run = self._active.get(surgery_id)
    if run is None:
        return 0

    state = run.state
    queued = (state.pending_by_id.get(cid) for cid in state.pending_fifo)
    return sum(
        1 for entry in queued if entry is not None and entry.status == "pending"
    )
|
||||
|
||||
def pending_queue_position_1based(
    self, surgery_id: str, confirmation_id: str
) -> int | None:
    """Return the 1-based index of ``confirmation_id`` in ``pending_fifo``.

    The head of the queue is position 1.  Returns ``None`` when the surgery
    has no active session or the id is not in the FIFO.

    NOTE(review): every id in the FIFO is counted regardless of status, while
    ``pending_queue_pending_count`` only counts entries with status
    ``"pending"``.  If resolved ids can remain in the FIFO, the position may
    exceed that count -- confirm against the code that maintains the FIFO.
    """
    run = self._active.get(surgery_id)
    if run is None:
        return None

    matches = (
        position
        for position, queued_id in enumerate(run.state.pending_fifo, start=1)
        if queued_id == confirmation_id
    )
    return next(matches, None)
|
||||
|
||||
async def resolve_pending_confirmation(
|
||||
self,
|
||||
surgery_id: str,
|
||||
@@ -436,6 +466,8 @@ class SurgerySessionRegistry:
|
||||
return None
|
||||
state.last_detail_monotonic[dedupe_key] = now_m
|
||||
|
||||
state.pending_ordinal_next += 1
|
||||
ordinal = state.pending_ordinal_next
|
||||
confirm_id = str(uuid.uuid4())
|
||||
prompt = build_prompt_text(opts)
|
||||
pending = PendingConsumableConfirmation(
|
||||
@@ -446,6 +478,7 @@ class SurgerySessionRegistry:
|
||||
created_at=datetime.now(timezone.utc),
|
||||
model_top1_label=top_key,
|
||||
model_top1_confidence=top_confidence,
|
||||
enqueue_ordinal=ordinal,
|
||||
)
|
||||
state.pending_by_id[confirm_id] = pending
|
||||
state.pending_fifo.append(confirm_id)
|
||||
|
||||
@@ -183,7 +183,7 @@ def is_rejection_phrase(asr_text: str) -> bool:
|
||||
|
||||
|
||||
def build_prompt_text(options: list[tuple[str, float]]) -> str:
    """Build the spoken confirmation prompt for a list of candidate options.

    The prompt starts with a short lead-in and then enumerates every option
    as a numbered sentence, starting at 1.

    Args:
        options: ``(name, confidence)`` pairs; only the name is spoken.

    Returns:
        The concatenated prompt text.
    """
    # Only the short lead-in is used; the longer previous wording was a dead
    # assignment that was overwritten immediately.
    numbered = (
        f"第{i}个,{name}。" for i, (name, _conf) in enumerate(options, start=1)
    )
    return "请确认。" + "".join(numbered)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
@@ -52,6 +53,7 @@ class VoiceConfirmationService:
|
||||
audits: VoiceAuditRepository,
|
||||
session_factory: async_sessionmaker | None = None,
|
||||
audit_emitter: VoiceAuditEmitter | None = None,
|
||||
on_pending_queue_advanced: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
self._s = settings
|
||||
self._sessions = sessions
|
||||
@@ -64,6 +66,21 @@ class VoiceConfirmationService:
|
||||
audits=audits,
|
||||
session_factory=self._session_factory,
|
||||
)
|
||||
self._on_pending_queue_advanced = on_pending_queue_advanced
|
||||
|
||||
def set_on_pending_queue_advanced(
    self, cb: Callable[[str], Awaitable[None]] | None
) -> None:
    """Install, replace, or clear (with ``None``) the queue-advanced callback.

    The callback is awaited with the surgery id by
    ``_notify_pending_queue_advanced``.
    """
    self._on_pending_queue_advanced = cb
|
||||
|
||||
async def _notify_pending_queue_advanced(self, surgery_id: str) -> None:
    """Await the queue-advanced callback, if one is installed.

    Any exception raised by the callback is logged as a warning and
    swallowed, so a failing listener does not break the caller.
    """
    callback = self._on_pending_queue_advanced
    if callback is not None:
        try:
            await callback(surgery_id)
        except Exception as exc:
            logger.warning("on_pending_queue_advanced 回调失败: {}", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTS:保持对外接口不变
|
||||
@@ -240,6 +257,7 @@ class VoiceConfirmationService:
|
||||
chosen_label=chosen,
|
||||
rejected=rejected,
|
||||
)
|
||||
await self._notify_pending_queue_advanced(surgery_id)
|
||||
final_status = "rejected" if rejected else "recognized"
|
||||
await self._emitter.success(
|
||||
source="wav",
|
||||
@@ -344,6 +362,7 @@ class VoiceConfirmationService:
|
||||
chosen_label=chosen,
|
||||
rejected=rejected,
|
||||
)
|
||||
await self._notify_pending_queue_advanced(surgery_id)
|
||||
final_status = "rejected" if rejected else "recognized"
|
||||
await self._emitter.success(
|
||||
source="text",
|
||||
@@ -448,7 +467,9 @@ class VoiceConfirmationService:
|
||||
session_trace,
|
||||
) -> object:
|
||||
try:
|
||||
return await run_in_threadpool(self._baidu.asr, pcm, "pcm", 16000, None)
|
||||
return await run_in_threadpool(
|
||||
self._baidu.asr_16k_mono_pcm_or_wav_fallback, pcm
|
||||
)
|
||||
except BaiduSpeechNotConfiguredError as exc:
|
||||
raise await self._emitter.fail(
|
||||
source="wav",
|
||||
@@ -600,5 +621,6 @@ class VoiceConfirmationService:
|
||||
include_extra={
|
||||
"confirmation_id": confirmation_id,
|
||||
"retry_remaining": retry_remaining,
|
||||
"asr_text": text,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""语音桌面终端:assignment 状态、WebSocket 推送与 HTTP 轮询兜底。"""
|
||||
"""语音桌面终端:assignment 状态、WebSocket 推送与 HTTP 拉取兜底。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import WebSocket
|
||||
from loguru import logger
|
||||
@@ -14,6 +16,8 @@ from starlette.websockets import WebSocketDisconnect
|
||||
from app.config import Settings
|
||||
from app.services.voice_terminal_binding import VoiceTerminalBindingIndex
|
||||
|
||||
PendingHeadFetcher = Callable[[str], Awaitable[Any]]
|
||||
|
||||
|
||||
async def assign_voice_terminal_after_recording_started(
|
||||
hub: VoiceTerminalHub,
|
||||
@@ -45,12 +49,18 @@ async def assign_voice_terminal_after_recording_started(
|
||||
class VoiceTerminalHub:
|
||||
"""进程内终端连接与当前手术分配(多 worker 需另行同步)。"""
|
||||
|
||||
def __init__(
    self,
    settings: Settings,
    *,
    pending_head_fetcher: PendingHeadFetcher | None = None,
) -> None:
    """Initialise in-process terminal state.

    Args:
        settings: Application settings; the site configuration is loaded
            from it via ``load_or_site_config()`` to obtain voice bindings.
        pending_head_fetcher: Optional async callable taking a surgery id and
            returning the payload for the current head of the pending queue.
            Stored for use when pushing the queue head to terminals.
    """
    cfg = settings.load_or_site_config()
    # No site configuration means no bindings are available.
    self._bindings = cfg.voice_bindings if cfg else None
    self._assignments: dict[str, str] = {}
    self._lock = Lock()
    self._connections: dict[str, set[WebSocket]] = defaultdict(set)
    self._pending_head_fetcher = pending_head_fetcher
|
||||
|
||||
@property
|
||||
def bindings(self) -> VoiceTerminalBindingIndex | None:
|
||||
@@ -81,6 +91,15 @@ class VoiceTerminalHub:
|
||||
tid,
|
||||
surgery_id,
|
||||
)
|
||||
self.schedule_notify_pending_head(tid, surgery_id)
|
||||
|
||||
def schedule_notify_pending_head(self, terminal_id: str, surgery_id: str) -> None:
    """Push the queue head (including TTS) in the background without blocking.

    Both ids are stripped; nothing is scheduled when either is empty or
    ``None``.  Must be called while an event loop is running, because the
    push is scheduled with ``asyncio.create_task``.

    Args:
        terminal_id: Target voice terminal id.
        surgery_id: Surgery whose pending queue head is pushed.
    """
    tid = (terminal_id or "").strip()
    sid = (surgery_id or "").strip()
    if not tid or not sid:
        return

    # The event loop keeps only weak references to tasks, so an unreferenced
    # task may be garbage-collected before it finishes.  Hold a strong
    # reference until completion.  The set is created lazily so this method
    # does not depend on any other initialisation code.
    tasks = getattr(self, "_background_tasks", None)
    if tasks is None:
        tasks = set()
        self._background_tasks = tasks

    task = asyncio.create_task(self._notify_pending_head_safe(tid, sid))
    tasks.add(task)
    task.add_done_callback(tasks.discard)
|
||||
|
||||
async def notify_end(self, terminal_id: str | None, surgery_id: str) -> None:
|
||||
if not terminal_id:
|
||||
@@ -103,6 +122,50 @@ class VoiceTerminalHub:
|
||||
surgery_id,
|
||||
)
|
||||
|
||||
async def notify_pending_head(self, terminal_id: str, surgery_id: str) -> None:
    """Push the current FIFO head (including TTS) to a terminal.

    When the fetcher returns ``None`` a ``voice_pending_empty`` message is
    broadcast instead.  Nothing happens when no fetcher is configured or when
    either id is blank after stripping.  Fetch and serialization failures are
    logged as warnings and end the call without broadcasting.
    """
    fetch_head = self._pending_head_fetcher
    target = terminal_id.strip()
    surgery = (surgery_id or "").strip()
    if not (fetch_head and target and surgery):
        return

    try:
        head = await fetch_head(surgery)
    except Exception as exc:
        logger.warning(
            "voice_pending 构建失败 surgery_id={} terminal_id={}: {}",
            surgery,
            target,
            exc,
        )
        return

    if head is None:
        message = {"type": "voice_pending_empty", "surgery_id": surgery}
    else:
        try:
            message = head.model_dump(mode="json")
        except Exception as exc:
            logger.warning("voice_pending 序列化失败 surgery_id={}: {}", surgery, exc)
            return
        # Tag the serialized payload so the terminal can dispatch on it.
        message["type"] = "voice_pending"

    await self._broadcast(target, message)
|
||||
|
||||
async def _notify_pending_head_safe(
    self, terminal_id: str, surgery_id: str
) -> None:
    """Run ``notify_pending_head`` and log, rather than propagate, any failure.

    Intended as the body of a background task, where an unhandled exception
    would otherwise go unobserved.
    """
    try:
        await self.notify_pending_head(terminal_id, surgery_id)
    except Exception as err:
        message = "后台 voice_pending 推送失败 terminal_id={} surgery_id={}: {}"
        logger.warning(message, terminal_id, surgery_id, err)
|
||||
|
||||
async def handle_websocket(self, websocket: WebSocket, terminal_id: str) -> None:
|
||||
tid = terminal_id.strip()
|
||||
if not tid:
|
||||
@@ -125,6 +188,7 @@ class VoiceTerminalHub:
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
self.schedule_notify_pending_head(tid, sid)
|
||||
# 不能用 receive_text():桌面端 websocket-client 会发 ping/二进制控制帧,
|
||||
# ASGI 可能呈现为无 "text" 的 websocket.receive,receive_text 会 KeyError 并掐断连接。
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user