ver0.1
This commit is contained in:
24
app/api.py
24
app/api.py
@@ -42,6 +42,11 @@ from app.surgery_errors import SurgeryPipelineError
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 上传 WAV 后 ASR/解析失败:HTTP 200 + status=failed,待确认项仍留在 FIFO 队首,便于桌面端重试。
|
||||
_RECOVERABLE_VOICE_RESOLVE_CODES = frozenset(
|
||||
{"VOICE_ASR_FAILED", "VOICE_TEXT_EMPTY", "VOICE_PARSE_FAILED"}
|
||||
)
|
||||
|
||||
|
||||
def _pipeline_error_detail(exc: SurgeryPipelineError, surgery_id: str) -> dict:
|
||||
d: dict = {
|
||||
@@ -367,7 +372,7 @@ async def get_surgery_result(
|
||||
tags=["client"],
|
||||
summary="拉取待确认耗材(含 TTS 音频)",
|
||||
description=(
|
||||
"返回当前 FIFO 队首的一条低置信度识别;"
|
||||
"返回当前 FIFO 队首的一条低置信度识别;`pending_queue_length` 为仍排队中的 pending 条数(含本条)。"
|
||||
"响应内 `prompt_audio_mp3_base64` 为与 `prompt_text` 一致的 MP3(Base64),客户端可直接解码播放。"
|
||||
"无待确认项时返回 404;提示文本为空为 422;未配置百度或 TTS 失败为 503(不返回空音频兜底)。"
|
||||
"医生确认后请使用 `POST .../resolve` 上传 WAV。"
|
||||
@@ -416,6 +421,7 @@ async def get_pending_consumable_confirmation(
|
||||
"multipart/form-data 上传单个 WAV 文件(字段名 `audio`)。"
|
||||
"服务端将音频存入 MinIO、调用百度 ASR 识别、解析候选项并完成确认。"
|
||||
"解析并确认后记一条消耗明细;若语音表示否认全部候选则不记消耗。"
|
||||
"ASR/解析可重试失败时仍返回 HTTP 200,`status`=`failed`,队首待确认项不弹出,便于桌面端重试。"
|
||||
),
|
||||
)
|
||||
async def resolve_pending_consumable_confirmation(
|
||||
@@ -467,6 +473,21 @@ async def resolve_pending_consumable_confirmation(
|
||||
content_type=audio.content_type,
|
||||
)
|
||||
except SurgeryPipelineError as exc:
|
||||
if exc.code in _RECOVERABLE_VOICE_RESOLVE_CODES:
|
||||
extra = exc.extra or {}
|
||||
asr_txt = extra.get("asr_text")
|
||||
akey = extra.get("audio_object_key")
|
||||
return SurgeryPendingConfirmationResolveResponse(
|
||||
surgery_id=surgery_id,
|
||||
confirmation_id=confirmation_id,
|
||||
status="failed",
|
||||
message=exc.message,
|
||||
resolved_label=None,
|
||||
rejected=False,
|
||||
asr_text=asr_txt if isinstance(asr_txt, str) else None,
|
||||
audio_object_key=akey if isinstance(akey, str) else None,
|
||||
error_code=exc.code,
|
||||
)
|
||||
_raise_confirmation_http(exc, surgery_id)
|
||||
return SurgeryPendingConfirmationResolveResponse(
|
||||
surgery_id=surgery_id,
|
||||
@@ -477,4 +498,5 @@ async def resolve_pending_consumable_confirmation(
|
||||
rejected=result.rejected,
|
||||
asr_text=result.asr_text,
|
||||
audio_object_key=result.audio_object_key,
|
||||
error_code=None,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from app.baked import algorithm as baked_algorithm
|
||||
@@ -111,6 +111,7 @@ class _ServerGroup(_SettingsGroup):
|
||||
_FIELDS = (
|
||||
"server_host",
|
||||
"server_port",
|
||||
"server_reload",
|
||||
)
|
||||
|
||||
_PACKAGE_DIR = Path(__file__).resolve().parent
|
||||
@@ -132,6 +133,11 @@ class Settings(BaseSettings):
|
||||
|
||||
server_host: str = "0.0.0.0"
|
||||
server_port: int = Field(default=38080, ge=1, le=65535)
|
||||
#: 开发用:等价 ``uvicorn --reload``,Python 代码变更时重载进程(勿在生产开启)。
|
||||
server_reload: bool = Field(
|
||||
default=False,
|
||||
validation_alias=AliasChoices("server_reload", "UVICORN_RELOAD"),
|
||||
)
|
||||
|
||||
video_default_backend: Literal["rtsp", "hikvision_sdk", "auto"] = "rtsp"
|
||||
video_camera_backend_overrides_json: str = ""
|
||||
|
||||
@@ -98,7 +98,18 @@ def build_container(
|
||||
voice_confirmation=voice,
|
||||
session_factory=sf,
|
||||
)
|
||||
voice_hub = VoiceTerminalHub(s)
|
||||
voice_hub = VoiceTerminalHub(
|
||||
s,
|
||||
pending_head_fetcher=pipeline.get_pending_confirmation_for_client,
|
||||
)
|
||||
|
||||
async def _on_pending_queue_advanced(surgery_id: str) -> None:
|
||||
tid = camera_mgr.get_voice_terminal_id_if_active(surgery_id)
|
||||
if tid:
|
||||
voice_hub.schedule_notify_pending_head(tid, surgery_id)
|
||||
|
||||
voice.set_on_pending_queue_advanced(_on_pending_queue_advanced)
|
||||
camera_mgr.set_voice_terminal_hub(voice_hub)
|
||||
return AppContainer(
|
||||
settings=s,
|
||||
consumable_vision_algorithm_service=vision,
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
{
|
||||
"video_rtsp_urls": {
|
||||
"or-cam-01": "rtsp://127.0.0.1:18554/demo1"
|
||||
"or-cam-01": "rtsp://admin:Aa183137@192.168.3.2:554/Streaming/Channels/101",
|
||||
"or-cam-02": "rtsp://admin:Aa183137@192.168.3.3:554/Streaming/Channels/101",
|
||||
"or-cam-03": "rtsp://admin:Aa183137@192.168.3.4:554/Streaming/Channels/101",
|
||||
"or-cam-04": "rtsp://admin:Aa183137@192.168.3.5:554/Streaming/Channels/101"
|
||||
},
|
||||
"voice_or_room_bindings": [
|
||||
{
|
||||
"camera_ids": [
|
||||
"or-cam-01",
|
||||
"or-cam-02"
|
||||
"or-cam-02",
|
||||
"or-cam-03",
|
||||
"or-cam-04"
|
||||
],
|
||||
"or_room_id": "OR-DEMO",
|
||||
"or_room_id": "OR-TEST",
|
||||
"voice_terminal_id": "desktop-1"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -15,7 +15,12 @@ class SurgeryStartRequest(BaseModel):
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"surgery_id": "123456",
|
||||
"camera_ids": ["or-cam-01", "or-cam-02"],
|
||||
"camera_ids": [
|
||||
"or-cam-01",
|
||||
"or-cam-02",
|
||||
"or-cam-03",
|
||||
"or-cam-04",
|
||||
],
|
||||
"candidate_consumables": ["纱布", "缝线", "止血钳"],
|
||||
}
|
||||
}
|
||||
@@ -164,6 +169,18 @@ class SurgeryPendingConfirmationResponse(BaseModel):
|
||||
|
||||
surgery_id: str
|
||||
confirmation_id: str
|
||||
pending_queue_length: int = Field(
|
||||
ge=1,
|
||||
description="本台手术待确认 FIFO 中仍为 pending 的条数(含本条),用于客户端按序播报。",
|
||||
)
|
||||
pending_queue_position: int = Field(
|
||||
ge=1,
|
||||
description="本条在当前 FIFO 中的排队序号(1-based,队首为 1,与队尾相隔 pending_queue_length-1 条等待)。",
|
||||
)
|
||||
pending_cumulative_ordinal: int = Field(
|
||||
ge=1,
|
||||
description="本场手术中待确认任务自入队以来的累计序号(第几条入队任务)。",
|
||||
)
|
||||
prompt_text: str = Field(description="可直接用于展示或无障碍朗读的话术(与 MP3 内容一致)。")
|
||||
prompt_audio_mp3_base64: str = Field(
|
||||
description=(
|
||||
@@ -181,8 +198,17 @@ class SurgeryPendingConfirmationResponse(BaseModel):
|
||||
class SurgeryPendingConfirmationResolveResponse(BaseModel):
|
||||
surgery_id: str
|
||||
confirmation_id: str
|
||||
status: str = Field(description="accepted")
|
||||
status: str = Field(
|
||||
description=(
|
||||
"``accepted``:已确认或已否认并结案;"
|
||||
"``failed``:ASR/解析等可重试失败,队首待确认项未移除。"
|
||||
),
|
||||
)
|
||||
message: str
|
||||
error_code: str | None = Field(
|
||||
default=None,
|
||||
description="仅 status=failed 时与错误码一致(如 VOICE_ASR_FAILED)。",
|
||||
)
|
||||
resolved_label: str | None = Field(
|
||||
default=None,
|
||||
description="解析并确认后的耗材名称;否认全部候选时为 null。",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import array
|
||||
import io
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -9,12 +10,57 @@ import wave
|
||||
from typing import Final
|
||||
|
||||
_BAIDU_RATE: Final[int] = 16000
|
||||
# 诊室麦克风常见音量偏小,百度短语音 3301「语音质量错误」多与有效幅度过低有关。
|
||||
_NORM_TARGET_PEAK: Final[int] = 12000
|
||||
_NORM_MAX_GAIN: Final[float] = 80.0
|
||||
|
||||
|
||||
class WavDecodeError(ValueError):
|
||||
"""Uploaded bytes are not a valid WAV or cannot be converted."""
|
||||
|
||||
|
||||
def pcm_s16le_to_wav_bytes(pcm: bytes, *, sample_rate: int = _BAIDU_RATE) -> bytes:
|
||||
"""将 raw s16le mono PCM 打成标准 WAV,供百度 ``format=wav`` 重试等场景。"""
|
||||
if not pcm:
|
||||
raise WavDecodeError("Empty PCM")
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(pcm)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def normalize_pcm_s16le_for_baidu(pcm: bytes) -> bytes:
|
||||
"""提升过弱信号幅度,降低 err_no=3301(speech quality)概率;已足够响的音频不改。"""
|
||||
if len(pcm) < 2 or len(pcm) % 2 != 0:
|
||||
return pcm
|
||||
samples = array.array("h")
|
||||
samples.frombytes(pcm)
|
||||
if not samples:
|
||||
return pcm
|
||||
peak = 0
|
||||
for s in samples:
|
||||
a = abs(int(s))
|
||||
if a > peak:
|
||||
peak = a
|
||||
if peak == 0 or peak >= _NORM_TARGET_PEAK:
|
||||
return pcm
|
||||
scale = min(_NORM_MAX_GAIN, float(_NORM_TARGET_PEAK) / float(peak))
|
||||
if scale <= 1.0:
|
||||
return pcm
|
||||
out = array.array("h")
|
||||
for s in samples:
|
||||
v = int(round(float(s) * scale))
|
||||
if v > 32767:
|
||||
v = 32767
|
||||
elif v < -32768:
|
||||
v = -32768
|
||||
out.append(v)
|
||||
return out.tobytes()
|
||||
|
||||
|
||||
def wav_bytes_to_pcm16k_mono_s16le(wav_bytes: bytes) -> bytes:
|
||||
"""
|
||||
Prefer ffmpeg for arbitrary channel count / sample rate.
|
||||
@@ -57,7 +103,7 @@ def _ffmpeg_to_pcm16k(wav_bytes: bytes, ffmpeg: str) -> bytes:
|
||||
raise WavDecodeError(f"ffmpeg wav decode failed: {err or proc.returncode}")
|
||||
if not proc.stdout:
|
||||
raise WavDecodeError("ffmpeg produced empty PCM")
|
||||
return proc.stdout
|
||||
return normalize_pcm_s16le_for_baidu(proc.stdout)
|
||||
|
||||
|
||||
def _stdlib_wave_to_pcm16k(wav_bytes: bytes) -> bytes:
|
||||
@@ -96,6 +142,6 @@ def _stdlib_wave_to_pcm16k(wav_bytes: bytes) -> bytes:
|
||||
l_s, r_s = struct.unpack("<hh", chunk)
|
||||
m = max(min((l_s + r_s) // 2, 32767), -32768)
|
||||
out.extend(struct.pack("<h", m))
|
||||
return bytes(out)
|
||||
return normalize_pcm_s16le_for_baidu(bytes(out))
|
||||
|
||||
return raw
|
||||
return normalize_pcm_s16le_for_baidu(raw)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from aip import AipSpeech
|
||||
|
||||
from app.config import Settings, settings as _default_settings
|
||||
from app.services.audio_wav import pcm_s16le_to_wav_bytes
|
||||
|
||||
|
||||
class BaiduSpeechNotConfiguredError(RuntimeError):
|
||||
@@ -60,6 +61,31 @@ class BaiduSpeechService:
|
||||
merged["dev_pid"] = int(self._s.baidu_speech_asr_dev_pid)
|
||||
return self._client_or_raise().asr(speech, format, rate, merged)
|
||||
|
||||
def asr_16k_mono_pcm_or_wav_fallback(
|
||||
self,
|
||||
pcm_s16le: bytes,
|
||||
*,
|
||||
rate: int = 16000,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""先按 raw PCM 识别;若返回 err_no=3301(语音质量错误),再用 WAV 封装重试一次。
|
||||
|
||||
部分环境下 PCM 与 WAV 路径对边界样本表现不一致,重试可提高成功率。
|
||||
"""
|
||||
r = self.asr(pcm_s16le, "pcm", rate, options)
|
||||
if not isinstance(r, dict):
|
||||
return r
|
||||
if r.get("err_no") != 3301:
|
||||
return r
|
||||
if len(pcm_s16le) < 1000:
|
||||
return r
|
||||
try:
|
||||
wav = pcm_s16le_to_wav_bytes(pcm_s16le, sample_rate=rate)
|
||||
except Exception:
|
||||
return r
|
||||
r2 = self.asr(wav, "wav", rate, options)
|
||||
return r2 if isinstance(r2, dict) else r
|
||||
|
||||
def synthesis(
|
||||
self,
|
||||
text: str,
|
||||
|
||||
@@ -111,6 +111,10 @@ class SurgeryPipeline:
|
||||
pending = self._sessions.next_pending_confirmation(surgery_id)
|
||||
if pending is None:
|
||||
return None
|
||||
queue_len = self._sessions.pending_queue_pending_count(surgery_id)
|
||||
qpos = self._sessions.pending_queue_position_1based(surgery_id, pending.id)
|
||||
if qpos is None or qpos < 1:
|
||||
qpos = 1
|
||||
mp3 = await run_in_threadpool(
|
||||
self._voice.synthesize_prompt_to_mp3,
|
||||
pending.prompt_text,
|
||||
@@ -119,6 +123,9 @@ class SurgeryPipeline:
|
||||
return SurgeryPendingConfirmationResponse(
|
||||
surgery_id=surgery_id,
|
||||
confirmation_id=pending.id,
|
||||
pending_queue_length=max(1, queue_len),
|
||||
pending_queue_position=qpos,
|
||||
pending_cumulative_ordinal=max(1, pending.enqueue_ordinal),
|
||||
prompt_text=pending.prompt_text,
|
||||
prompt_audio_mp3_base64=b64,
|
||||
options=[
|
||||
|
||||
@@ -14,6 +14,8 @@ from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from app.services.voice_terminal_hub import VoiceTerminalHub
|
||||
|
||||
from app.baked import pipeline as bp
|
||||
from app.services.consumable_vision_algorithm import (
|
||||
PredictionCandidate,
|
||||
@@ -56,8 +58,13 @@ class VisionClassificationHandler:
|
||||
self,
|
||||
*,
|
||||
registry: SurgerySessionRegistry,
|
||||
voice_terminal_hub: VoiceTerminalHub | None = None,
|
||||
) -> None:
|
||||
self._registry = registry
|
||||
self._voice_hub = voice_terminal_hub
|
||||
|
||||
def attach_voice_terminal_hub(self, hub: VoiceTerminalHub | None) -> None:
|
||||
self._voice_hub = hub
|
||||
|
||||
def _append_vision_consumption_window_if_ready(
|
||||
self,
|
||||
@@ -212,3 +219,7 @@ class VisionClassificationHandler:
|
||||
confirmation_id=cid,
|
||||
doctor_id=bp.VIDEO_RESULT_DOCTOR_ID,
|
||||
)
|
||||
hub = self._voice_hub
|
||||
vtid = (state.voice_terminal_id or "").strip()
|
||||
if hub is not None and vtid and surgery_id:
|
||||
hub.schedule_notify_pending_head(vtid, surgery_id)
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.services.consumable_vision_algorithm import (
|
||||
from app.services.video.archive_persister import ArchivePersister
|
||||
from app.services.video.backend_resolver import BackendResolver
|
||||
from app.services.video.classification_handler import VisionClassificationHandler
|
||||
from app.services.voice_terminal_hub import VoiceTerminalHub
|
||||
from app.services.video.hikvision_runtime import HikvisionInitRefCount, HikvisionRuntime
|
||||
from app.services.video.inference_aggregator import WindowInferenceAggregator
|
||||
from app.services.video.session_registry import (
|
||||
@@ -101,6 +102,16 @@ class CameraSessionManager:
|
||||
registry=self._registry,
|
||||
)
|
||||
|
||||
def set_voice_terminal_hub(self, hub: VoiceTerminalHub | None) -> None:
|
||||
self._classifier_handler.attach_voice_terminal_hub(hub)
|
||||
|
||||
def get_voice_terminal_id_if_active(self, surgery_id: str) -> str | None:
|
||||
run = self._registry.get_running(surgery_id)
|
||||
if run is None:
|
||||
return None
|
||||
tid = (run.state.voice_terminal_id or "").strip()
|
||||
return tid or None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ------------------------------------------------------------------
|
||||
@@ -310,6 +321,16 @@ class CameraSessionManager:
|
||||
) -> PendingConsumableConfirmation | None:
|
||||
return self._registry.next_pending_confirmation(surgery_id)
|
||||
|
||||
def pending_queue_pending_count(self, surgery_id: str) -> int:
|
||||
return self._registry.pending_queue_pending_count(surgery_id)
|
||||
|
||||
def pending_queue_position_1based(
|
||||
self, surgery_id: str, confirmation_id: str
|
||||
) -> int | None:
|
||||
return self._registry.pending_queue_position_1based(
|
||||
surgery_id, confirmation_id
|
||||
)
|
||||
|
||||
async def resolve_pending_confirmation(
|
||||
self,
|
||||
surgery_id: str,
|
||||
|
||||
@@ -52,6 +52,8 @@ class PendingConsumableConfirmation:
|
||||
model_top1_confidence: float
|
||||
#: 本轮待确认在解析失败时累计次数(首败 + 重试),供 API 计算 retry_remaining。
|
||||
voice_parse_failures: int = 0
|
||||
#: 本场手术中待确认任务入队时的累计序号(自 1 起,入队时递增)。
|
||||
enqueue_ordinal: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,6 +89,8 @@ class SurgerySessionState:
|
||||
surgery_started_wall: float | None = None
|
||||
#: 术间绑定配置解析出的语音桌面终端 ID;停录时用于推送 end。
|
||||
voice_terminal_id: str | None = None
|
||||
#: 待确认入队已分配到的最大序号(与 ``pending_by_id`` 中 ``enqueue_ordinal`` 一致递增)。
|
||||
pending_ordinal_next: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -203,6 +207,32 @@ class SurgerySessionRegistry:
|
||||
return p
|
||||
return None
|
||||
|
||||
def pending_queue_pending_count(self, surgery_id: str) -> int:
|
||||
"""FIFO 中仍为 pending 的条数(与 ``next_pending_confirmation`` 遍历规则一致)。"""
|
||||
run = self._active.get(surgery_id)
|
||||
if run is None:
|
||||
return 0
|
||||
st = run.state
|
||||
n = 0
|
||||
for cid in st.pending_fifo:
|
||||
p = st.pending_by_id.get(cid)
|
||||
if p is not None and p.status == "pending":
|
||||
n += 1
|
||||
return n
|
||||
|
||||
def pending_queue_position_1based(
|
||||
self, surgery_id: str, confirmation_id: str
|
||||
) -> int | None:
|
||||
"""``confirmation_id`` 在当前 ``pending_fifo`` 中的 1-based 位置(队首为 1)。"""
|
||||
run = self._active.get(surgery_id)
|
||||
if run is None:
|
||||
return None
|
||||
st = run.state
|
||||
try:
|
||||
return st.pending_fifo.index(confirmation_id) + 1
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def resolve_pending_confirmation(
|
||||
self,
|
||||
surgery_id: str,
|
||||
@@ -436,6 +466,8 @@ class SurgerySessionRegistry:
|
||||
return None
|
||||
state.last_detail_monotonic[dedupe_key] = now_m
|
||||
|
||||
state.pending_ordinal_next += 1
|
||||
ordinal = state.pending_ordinal_next
|
||||
confirm_id = str(uuid.uuid4())
|
||||
prompt = build_prompt_text(opts)
|
||||
pending = PendingConsumableConfirmation(
|
||||
@@ -446,6 +478,7 @@ class SurgerySessionRegistry:
|
||||
created_at=datetime.now(timezone.utc),
|
||||
model_top1_label=top_key,
|
||||
model_top1_confidence=top_confidence,
|
||||
enqueue_ordinal=ordinal,
|
||||
)
|
||||
state.pending_by_id[confirm_id] = pending
|
||||
state.pending_fifo.append(confirm_id)
|
||||
|
||||
@@ -183,7 +183,7 @@ def is_rejection_phrase(asr_text: str) -> bool:
|
||||
|
||||
|
||||
def build_prompt_text(options: list[tuple[str, float]]) -> str:
|
||||
parts = ["请确认刚才使用的耗材是下面哪一项。"]
|
||||
parts = ["请确认。"]
|
||||
for i, (name, _conf) in enumerate(options, start=1):
|
||||
parts.append(f"第{i}个,{name}。")
|
||||
return "".join(parts)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
@@ -52,6 +53,7 @@ class VoiceConfirmationService:
|
||||
audits: VoiceAuditRepository,
|
||||
session_factory: async_sessionmaker | None = None,
|
||||
audit_emitter: VoiceAuditEmitter | None = None,
|
||||
on_pending_queue_advanced: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
self._s = settings
|
||||
self._sessions = sessions
|
||||
@@ -64,6 +66,21 @@ class VoiceConfirmationService:
|
||||
audits=audits,
|
||||
session_factory=self._session_factory,
|
||||
)
|
||||
self._on_pending_queue_advanced = on_pending_queue_advanced
|
||||
|
||||
def set_on_pending_queue_advanced(
|
||||
self, cb: Callable[[str], Awaitable[None]] | None
|
||||
) -> None:
|
||||
self._on_pending_queue_advanced = cb
|
||||
|
||||
async def _notify_pending_queue_advanced(self, surgery_id: str) -> None:
|
||||
cb = self._on_pending_queue_advanced
|
||||
if cb is None:
|
||||
return
|
||||
try:
|
||||
await cb(surgery_id)
|
||||
except Exception as exc:
|
||||
logger.warning("on_pending_queue_advanced 回调失败: {}", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TTS:保持对外接口不变
|
||||
@@ -240,6 +257,7 @@ class VoiceConfirmationService:
|
||||
chosen_label=chosen,
|
||||
rejected=rejected,
|
||||
)
|
||||
await self._notify_pending_queue_advanced(surgery_id)
|
||||
final_status = "rejected" if rejected else "recognized"
|
||||
await self._emitter.success(
|
||||
source="wav",
|
||||
@@ -344,6 +362,7 @@ class VoiceConfirmationService:
|
||||
chosen_label=chosen,
|
||||
rejected=rejected,
|
||||
)
|
||||
await self._notify_pending_queue_advanced(surgery_id)
|
||||
final_status = "rejected" if rejected else "recognized"
|
||||
await self._emitter.success(
|
||||
source="text",
|
||||
@@ -448,7 +467,9 @@ class VoiceConfirmationService:
|
||||
session_trace,
|
||||
) -> object:
|
||||
try:
|
||||
return await run_in_threadpool(self._baidu.asr, pcm, "pcm", 16000, None)
|
||||
return await run_in_threadpool(
|
||||
self._baidu.asr_16k_mono_pcm_or_wav_fallback, pcm
|
||||
)
|
||||
except BaiduSpeechNotConfiguredError as exc:
|
||||
raise await self._emitter.fail(
|
||||
source="wav",
|
||||
@@ -600,5 +621,6 @@ class VoiceConfirmationService:
|
||||
include_extra={
|
||||
"confirmation_id": confirmation_id,
|
||||
"retry_remaining": retry_remaining,
|
||||
"asr_text": text,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""语音桌面终端:assignment 状态、WebSocket 推送与 HTTP 轮询兜底。"""
|
||||
"""语音桌面终端:assignment 状态、WebSocket 推送与 HTTP 拉取兜底。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import WebSocket
|
||||
from loguru import logger
|
||||
@@ -14,6 +16,8 @@ from starlette.websockets import WebSocketDisconnect
|
||||
from app.config import Settings
|
||||
from app.services.voice_terminal_binding import VoiceTerminalBindingIndex
|
||||
|
||||
PendingHeadFetcher = Callable[[str], Awaitable[Any]]
|
||||
|
||||
|
||||
async def assign_voice_terminal_after_recording_started(
|
||||
hub: VoiceTerminalHub,
|
||||
@@ -45,12 +49,18 @@ async def assign_voice_terminal_after_recording_started(
|
||||
class VoiceTerminalHub:
|
||||
"""进程内终端连接与当前手术分配(多 worker 需另行同步)。"""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
*,
|
||||
pending_head_fetcher: PendingHeadFetcher | None = None,
|
||||
) -> None:
|
||||
cfg = settings.load_or_site_config()
|
||||
self._bindings = cfg.voice_bindings if cfg else None
|
||||
self._assignments: dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
self._connections: dict[str, set[WebSocket]] = defaultdict(set)
|
||||
self._pending_head_fetcher = pending_head_fetcher
|
||||
|
||||
@property
|
||||
def bindings(self) -> VoiceTerminalBindingIndex | None:
|
||||
@@ -81,6 +91,15 @@ class VoiceTerminalHub:
|
||||
tid,
|
||||
surgery_id,
|
||||
)
|
||||
self.schedule_notify_pending_head(tid, surgery_id)
|
||||
|
||||
def schedule_notify_pending_head(self, terminal_id: str, surgery_id: str) -> None:
|
||||
"""异步推送队首(含 TTS),不阻塞调用方。"""
|
||||
tid = terminal_id.strip()
|
||||
sid = (surgery_id or "").strip()
|
||||
if not tid or not sid:
|
||||
return
|
||||
asyncio.create_task(self._notify_pending_head_safe(tid, sid))
|
||||
|
||||
async def notify_end(self, terminal_id: str | None, surgery_id: str) -> None:
|
||||
if not terminal_id:
|
||||
@@ -103,6 +122,50 @@ class VoiceTerminalHub:
|
||||
surgery_id,
|
||||
)
|
||||
|
||||
async def notify_pending_head(self, terminal_id: str, surgery_id: str) -> None:
|
||||
"""向终端推送当前 FIFO 队首(含 TTS),无队首时推送 voice_pending_empty。"""
|
||||
fetcher = self._pending_head_fetcher
|
||||
tid = terminal_id.strip()
|
||||
sid = (surgery_id or "").strip()
|
||||
if not fetcher or not tid or not sid:
|
||||
return
|
||||
try:
|
||||
payload = await fetcher(sid)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"voice_pending 构建失败 surgery_id={} terminal_id={}: {}",
|
||||
sid,
|
||||
tid,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
if payload is None:
|
||||
await self._broadcast(
|
||||
tid,
|
||||
{"type": "voice_pending_empty", "surgery_id": sid},
|
||||
)
|
||||
return
|
||||
try:
|
||||
data = payload.model_dump(mode="json")
|
||||
except Exception as exc:
|
||||
logger.warning("voice_pending 序列化失败 surgery_id={}: {}", sid, exc)
|
||||
return
|
||||
data["type"] = "voice_pending"
|
||||
await self._broadcast(tid, data)
|
||||
|
||||
async def _notify_pending_head_safe(
|
||||
self, terminal_id: str, surgery_id: str
|
||||
) -> None:
|
||||
try:
|
||||
await self.notify_pending_head(terminal_id, surgery_id)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"后台 voice_pending 推送失败 terminal_id={} surgery_id={}: {}",
|
||||
terminal_id,
|
||||
surgery_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
async def handle_websocket(self, websocket: WebSocket, terminal_id: str) -> None:
|
||||
tid = terminal_id.strip()
|
||||
if not tid:
|
||||
@@ -125,6 +188,7 @@ class VoiceTerminalHub:
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
self.schedule_notify_pending_head(tid, sid)
|
||||
# 不能用 receive_text():桌面端 websocket-client 会发 ping/二进制控制帧,
|
||||
# ASGI 可能呈现为无 "text" 的 websocket.receive,receive_text 会 KeyError 并掐断连接。
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user