100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
"""Local faster-whisper ASR adapter — implements ASRProvider port."""
|
|
|
|
import asyncio
import os
import tempfile
import threading

from app.core.logging import get_logger
|
|
|
|
# Module-level logger from the project's logging helper, named after this module.
logger = get_logger(__name__)
|
|
|
|
_DEFAULT_CACHE_DIR = os.path.normpath(
|
|
os.path.join(
|
|
os.path.dirname(os.path.abspath(__file__)),
|
|
"..",
|
|
"..",
|
|
"..",
|
|
"models",
|
|
"whisper",
|
|
)
|
|
)
|
|
|
|
|
|
class WhisperASRProvider:
    """Local speech-to-text adapter backed by faster-whisper.

    Implements the ``ASRProvider`` port (per the module docstring). The
    underlying ``WhisperModel`` is created lazily on first use and then reused
    for every later call on this instance.

    Args:
        model_size: Model name/size passed to ``WhisperModel`` (e.g. ``"small"``).
        device: ``"auto"`` selects ``"cuda"`` when torch reports an available
            GPU and ``"cpu"`` otherwise; any other value is passed through.
        compute_type: ``"auto"`` selects ``"float16"`` on CUDA and ``"int8"``
            otherwise; any other value is passed through.
        cache_dir: Model directory. When non-empty, the model is loaded from it
            with ``local_files_only=True`` (no download). When empty, the
            module's default cache directory is used and downloads are allowed.
        language: Language code passed to the model's ``transcribe``
            (default ``"zh"``).
        beam_size: Beam width passed to the model's ``transcribe`` (default ``5``).
    """

    def __init__(
        self,
        model_size: str = "small",
        device: str = "auto",
        compute_type: str = "auto",
        cache_dir: str = "",
        language: str = "zh",
        beam_size: int = 5,
    ):
        self._model_size = model_size
        self._device = device
        self._compute_type = compute_type
        self._cache_dir = cache_dir
        self._language = language
        self._beam_size = beam_size
        self._model = None
        # Transcription runs in executor threads, so two concurrent first calls
        # could otherwise both construct a model; this lock serializes loading.
        self._load_lock = threading.Lock()

    @staticmethod
    def _detect_device() -> str:
        """Return ``"cuda"`` if torch reports an available GPU, else ``"cpu"``."""
        try:
            import torch  # type: ignore[import-untyped]
        except (ImportError, OSError):
            # torch is treated as optional here: a missing (ImportError) or
            # broken (OSError, e.g. unloadable native libs) install means we
            # cannot probe for CUDA, so fall back to CPU.
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_model(self) -> bool:
        """Create the Whisper model on first use.

        Returns:
            ``True`` if a model is available afterwards; ``False`` if loading
            failed. Failures are logged with a traceback, and a later call
            will try again.
        """
        if self._model is not None:
            return True
        with self._load_lock:
            # Re-check under the lock: another thread may have finished loading.
            if self._model is not None:
                return True
            try:
                from faster_whisper import WhisperModel

                device = self._device
                if device == "auto":
                    device = self._detect_device()

                compute_type = self._compute_type
                if compute_type == "auto":
                    compute_type = "float16" if device == "cuda" else "int8"

                download_root = self._cache_dir or _DEFAULT_CACHE_DIR
                # An explicit cache_dir disables downloads entirely.
                # NOTE(review): presumably intended for offline deployments with
                # a pre-populated cache — confirm.
                local_files_only = bool(self._cache_dir)
                os.makedirs(download_root, exist_ok=True)

                self._model = WhisperModel(
                    self._model_size,
                    device=device,
                    compute_type=compute_type,
                    download_root=download_root,
                    local_files_only=local_files_only,
                )
                return True
            except Exception as e:
                logger.exception("Failed to load Whisper model: %s", e)
                return False

    def ensure_ready(self) -> bool:
        """Load the model eagerly; return whether it is available."""
        return self._load_model()

    async def transcribe(self, audio: bytes, format: str = "m4a") -> str:
        """Transcribe an encoded audio clip to text.

        Model loading, temp-file I/O and decoding are blocking, so they run in
        the event loop's default executor rather than on the loop thread.

        Args:
            audio: Encoded audio bytes.
            format: Extension used for the temporary file (e.g. ``"m4a"``).
                Only ASCII letters and digits are kept.

        Returns:
            The transcript with surrounding whitespace stripped, or ``""`` when
            the audio is empty, the model could not be loaded, or decoding
            failed (failures are logged).
        """
        if not audio:
            return ""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio, format)

    def _transcribe_sync(self, audio: bytes, fmt: str) -> str:
        """Blocking implementation of :meth:`transcribe`; runs in a worker thread."""
        if not self._load_model():
            return ""

        # Keep only [A-Za-z0-9] so a caller-supplied format cannot inject path
        # separators (e.g. "../x") into the temp-file name.
        ext = "".join(ch for ch in fmt if ch.isascii() and ch.isalnum())
        suffix = f".{ext}" if ext else ""

        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                # Record the path before writing so a failed write is still
                # cleaned up in ``finally`` (the file is created with delete=False).
                tmp_path = tmp.name
                tmp.write(audio)

            segments, _info = self._model.transcribe(
                tmp_path,
                language=self._language,
                beam_size=self._beam_size,
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            # Segments may be produced lazily, so consume them here, before the
            # temp file is removed.
            return "".join(seg.text for seg in segments).strip()
        except Exception as e:
            logger.exception("Whisper transcribe failed: %s", e)
            return ""
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except FileNotFoundError:
                    pass
                except OSError as err:
                    # Best-effort cleanup: report it, but never mask the result.
                    logger.warning("Could not remove temp file %s: %s", tmp_path, err)
|