148 lines
4.3 KiB
Python
148 lines
4.3 KiB
Python
"""Decode WAV bytes to 16 kHz mono 16-bit PCM for Baidu short ASR."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import array
|
||
import io
|
||
import shutil
|
||
import subprocess
|
||
import wave
|
||
from typing import Final
|
||
|
||
# Sample rate (Hz) of the PCM handed to Baidu short-speech ASR.
_BAIDU_RATE: Final[int] = 16_000

# Clinic microphones are often quiet, and Baidu short-speech err 3301
# ("speech quality error") is mostly linked to too little usable amplitude.
# Weak recordings are therefore boosted towards this int16 peak ...
_NORM_TARGET_PEAK: Final[int] = 12_000
# ... but the gain factor applied is never larger than this.
_NORM_MAX_GAIN: Final[float] = 80.0
|
||
|
||
|
||
class WavDecodeError(ValueError):
    """Raised when audio bytes are not a usable WAV or cannot be converted to PCM."""
|
||
|
||
|
||
def pcm_s16le_to_wav_bytes(pcm: bytes, *, sample_rate: int = _BAIDU_RATE) -> bytes:
    """Wrap raw mono 16-bit little-endian PCM in a standard WAV container.

    Intended for cases such as retrying Baidu ASR with ``format=wav``.

    Args:
        pcm: Raw s16le mono sample data.
        sample_rate: Frame rate recorded in the WAV header (default ``_BAIDU_RATE``).

    Returns:
        The complete WAV file as bytes.

    Raises:
        WavDecodeError: If *pcm* is empty.
    """
    if len(pcm) == 0:
        raise WavDecodeError("Empty PCM")

    sink = io.BytesIO()
    with wave.open(sink, "wb") as writer:
        # (nchannels, sampwidth, framerate, nframes, comptype, compname);
        # nframes is patched by the wave module once the frames are written.
        writer.setparams((1, 2, sample_rate, 0, "NONE", "not compressed"))
        writer.writeframes(pcm)
    return sink.getvalue()
|
||
|
||
|
||
def normalize_pcm_s16le_for_baidu(pcm: bytes) -> bytes:
|
||
"""提升过弱信号幅度,降低 err_no=3301(speech quality)概率;已足够响的音频不改。"""
|
||
if len(pcm) < 2 or len(pcm) % 2 != 0:
|
||
return pcm
|
||
samples = array.array("h")
|
||
samples.frombytes(pcm)
|
||
if not samples:
|
||
return pcm
|
||
peak = 0
|
||
for s in samples:
|
||
a = abs(int(s))
|
||
if a > peak:
|
||
peak = a
|
||
if peak == 0 or peak >= _NORM_TARGET_PEAK:
|
||
return pcm
|
||
scale = min(_NORM_MAX_GAIN, float(_NORM_TARGET_PEAK) / float(peak))
|
||
if scale <= 1.0:
|
||
return pcm
|
||
out = array.array("h")
|
||
for s in samples:
|
||
v = int(round(float(s) * scale))
|
||
if v > 32767:
|
||
v = 32767
|
||
elif v < -32768:
|
||
v = -32768
|
||
out.append(v)
|
||
return out.tobytes()
|
||
|
||
|
||
def wav_bytes_to_pcm16k_mono_s16le(wav_bytes: bytes) -> bytes:
    """Convert an uploaded WAV payload into 16 kHz mono s16le PCM.

    When an ``ffmpeg`` executable is found on ``PATH`` it does the decoding, so
    any channel count / sample rate is accepted. Otherwise the stdlib ``wave``
    fallback is used, which only handles 16-bit PCM input already at 16 kHz
    with one or two channels.

    Raises:
        WavDecodeError: If *wav_bytes* is empty or the chosen decoder fails.
    """
    if not wav_bytes:
        raise WavDecodeError("Empty audio payload")

    ffmpeg_path = shutil.which("ffmpeg")
    if not ffmpeg_path:
        return _stdlib_wave_to_pcm16k(wav_bytes)
    return _ffmpeg_to_pcm16k(wav_bytes, ffmpeg_path)
|
||
|
||
|
||
def _ffmpeg_to_pcm16k(wav_bytes: bytes, ffmpeg: str, *, timeout: float = 120) -> bytes:
    """Decode and resample *wav_bytes* with ffmpeg into 16 kHz mono s16le PCM.

    Args:
        wav_bytes: Encoded audio, fed to ffmpeg on stdin.
        ffmpeg: Path of the ffmpeg executable.
        timeout: Seconds to wait for ffmpeg before giving up (default 120).

    Returns:
        PCM passed through ``normalize_pcm_s16le_for_baidu``.

    Raises:
        WavDecodeError: If ffmpeg cannot be started, times out, exits with a
            non-zero status, or writes no PCM.
    """
    cmd = [
        ffmpeg,
        "-nostdin",
        "-loglevel", "error",
        "-i", "pipe:0",
        "-f", "s16le",
        "-ac", "1",
        "-ar", str(_BAIDU_RATE),
        "pipe:1",
    ]
    try:
        proc = subprocess.run(
            cmd,
            input=wav_bytes,
            capture_output=True,
            timeout=timeout,
            check=False,
        )
    except subprocess.TimeoutExpired as exc:
        # Surface as the module's decode error instead of a raw subprocess one.
        raise WavDecodeError(f"ffmpeg wav decode timed out after {timeout}s") from exc
    except OSError as exc:
        # e.g. the binary found by shutil.which vanished or is not executable.
        raise WavDecodeError(f"ffmpeg could not be started: {exc}") from exc

    if proc.returncode != 0:
        err = (proc.stderr or b"").decode("utf-8", errors="replace").strip()
        raise WavDecodeError(f"ffmpeg wav decode failed: {err or proc.returncode}")
    if not proc.stdout:
        raise WavDecodeError("ffmpeg produced empty PCM")
    return normalize_pcm_s16le_for_baidu(proc.stdout)
|
||
|
||
|
||
def _stdlib_wave_to_pcm16k(wav_bytes: bytes) -> bytes:
    """Decode a WAV with the stdlib ``wave`` module (fallback without ffmpeg).

    Only 16-bit PCM at exactly ``_BAIDU_RATE`` Hz with one or two channels is
    supported. Stereo is down-mixed to mono by averaging left and right.

    Returns:
        Mono s16le PCM passed through ``normalize_pcm_s16le_for_baidu``.

    Raises:
        WavDecodeError: If the bytes are not a readable WAV (including a
            truncated header), the format is unsupported, or no complete
            audio frame is present.
    """
    import struct

    try:
        with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
            nchannels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            framerate = wf.getframerate()
            raw = wf.readframes(wf.getnframes())
    except (wave.Error, EOFError) as exc:
        # A too-short / truncated RIFF header raises EOFError, not wave.Error.
        detail = str(exc) or type(exc).__name__
        raise WavDecodeError(f"Invalid WAV: {detail}") from exc

    if sampwidth != 2:
        raise WavDecodeError(
            f"WAV sample width {sampwidth * 8} bit not supported without ffmpeg"
        )
    if nchannels not in (1, 2):
        raise WavDecodeError(
            f"WAV channels={nchannels} not supported without ffmpeg"
        )
    if framerate != _BAIDU_RATE:
        raise WavDecodeError(
            f"WAV rate {framerate} requires ffmpeg for resampling to {_BAIDU_RATE} Hz"
        )

    if nchannels == 2:
        # Average each interleaved L/R frame. The floor-mean of two int16
        # values always fits in int16, so no clipping is needed. A trailing
        # partial frame (fewer than 4 bytes) is dropped.
        whole = len(raw) - len(raw) % 4
        mono = [(left + right) // 2 for left, right in struct.iter_unpack("<hh", raw[:whole])]
        raw = struct.pack(f"<{len(mono)}h", *mono)

    # Mirror the ffmpeg path, which rejects empty output, instead of
    # silently returning an empty PCM buffer.
    if not raw:
        raise WavDecodeError("WAV contains no audio frames")

    return normalize_pcm_s16le_for_baidu(raw)
|