"""Decode WAV bytes to 16 kHz mono 16-bit PCM for Baidu short ASR."""

from __future__ import annotations

import io
import shutil
import subprocess
import wave
from typing import Final

_BAIDU_RATE: Final[int] = 16000


class WavDecodeError(ValueError):
    """Uploaded bytes are not a valid WAV or cannot be converted."""


def wav_bytes_to_pcm16k_mono_s16le(wav_bytes: bytes) -> bytes:
    """
    Convert WAV bytes to headerless mono signed 16-bit little-endian PCM at 16 kHz.

    Prefer ffmpeg for arbitrary channel count / sample rate.
    Falls back to stdlib `wave` when ffmpeg is unavailable (16-bit PCM only).

    Args:
        wav_bytes: Raw bytes of the uploaded WAV file.

    Returns:
        Raw PCM sample bytes (no WAV header).

    Raises:
        WavDecodeError: If the payload is empty or the selected decoder
            rejects it.
    """
    if not wav_bytes:
        raise WavDecodeError("Empty audio payload")
    ffmpeg = shutil.which("ffmpeg")
    if ffmpeg:
        return _ffmpeg_to_pcm16k(wav_bytes, ffmpeg)
    return _stdlib_wave_to_pcm16k(wav_bytes)


def _ffmpeg_to_pcm16k(wav_bytes: bytes, ffmpeg: str, *, timeout: float = 120) -> bytes:
    """Pipe *wav_bytes* through ffmpeg and return 16 kHz mono s16le PCM.

    Args:
        wav_bytes: Input audio, written to ffmpeg's stdin.
        ffmpeg: Path to the ffmpeg executable.
        timeout: Seconds to wait for ffmpeg before giving up.

    Raises:
        WavDecodeError: If ffmpeg cannot be started, times out, exits
            non-zero, or writes no PCM to stdout.
    """
    cmd = [
        ffmpeg,
        "-nostdin",
        "-loglevel",
        "error",
        "-i",
        "pipe:0",
        "-f",
        "s16le",
        "-ac",
        "1",
        "-ar",
        str(_BAIDU_RATE),
        "pipe:1",
    ]
    try:
        proc = subprocess.run(
            cmd,
            input=wav_bytes,
            capture_output=True,
            timeout=timeout,
            check=False,
        )
    except subprocess.TimeoutExpired as exc:
        # run() has already killed and reaped the child; report it as a decode
        # failure so callers only need to handle WavDecodeError.
        raise WavDecodeError(f"ffmpeg wav decode timed out after {timeout}s") from exc
    except OSError as exc:
        # e.g. the path returned by shutil.which() vanished or is not executable.
        raise WavDecodeError(f"ffmpeg could not be started: {exc}") from exc
    if proc.returncode != 0:
        # Strip so the message does not end with ffmpeg's trailing newline.
        err = (proc.stderr or b"").decode("utf-8", errors="replace").strip()
        raise WavDecodeError(f"ffmpeg wav decode failed: {err or proc.returncode}")
    if not proc.stdout:
        raise WavDecodeError("ffmpeg produced empty PCM")
    return proc.stdout


def _stdlib_wave_to_pcm16k(wav_bytes: bytes) -> bytes:
    try:
        with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
            nchannels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            framerate = wf.getframerate()
            nframes = wf.getnframes()
            raw = wf.readframes(nframes)
    except wave.Error as exc:
        raise WavDecodeError(f"Invalid WAV: {exc}") from exc
    if sampwidth != 2:
        raise WavDecodeError(
            f"WAV sample width {sampwidth * 8} bit not supported without ffmpeg"
        )
    if nchannels not in (1, 2):
        raise WavDecodeError(
            f"WAV channels={nchannels} not supported without ffmpeg"
        )
    if framerate != _BAIDU_RATE:
        raise WavDecodeError(
            f"WAV rate {framerate} requires ffmpeg for resampling to {_BAIDU_RATE} Hz"
        )
    if nchannels == 2:
        # de-interleave stereo s16le -> mono average
        import struct
out = bytearray() for i in range(0, len(raw), 4): chunk = raw[i : i + 4] if len(chunk) < 4: break l_s, r_s = struct.unpack("