"""Local faster-whisper ASR adapter — implements ASRProvider port.""" from app.core.logging import get_logger import os import tempfile logger = get_logger(__name__) _DEFAULT_CACHE_DIR = os.path.normpath( os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "models", "whisper", ) ) class WhisperASRProvider: def __init__( self, model_size: str = "small", device: str = "auto", compute_type: str = "auto", cache_dir: str = "", ): self._model_size = model_size self._device = device self._compute_type = compute_type self._cache_dir = cache_dir self._model = None def _load_model(self) -> bool: if self._model is not None: return True try: from faster_whisper import WhisperModel device = self._device compute_type = self._compute_type if device == "auto": try: import torch # type: ignore[import-untyped] device = "cuda" if torch.cuda.is_available() else "cpu" except ImportError: device = "cpu" if compute_type == "auto": compute_type = "float16" if device == "cuda" else "int8" download_root = self._cache_dir or _DEFAULT_CACHE_DIR local_files_only = bool(self._cache_dir) os.makedirs(download_root, exist_ok=True) self._model = WhisperModel( self._model_size, device=device, compute_type=compute_type, download_root=download_root, local_files_only=local_files_only, ) return True except Exception as e: logger.error("Failed to load Whisper model: %s", e) return False def ensure_ready(self) -> bool: return self._load_model() async def transcribe(self, audio: bytes, format: str = "m4a") -> str: self._load_model() if not self._model: return "" tmp_path = None try: with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmp: tmp.write(audio) tmp_path = tmp.name segments, _info = self._model.transcribe( tmp_path, language="zh", beam_size=5, vad_filter=True, vad_parameters={"min_silence_duration_ms": 500}, ) return "".join(seg.text for seg in segments).strip() except Exception as e: logger.error("Whisper transcribe failed: %s", e) return "" finally: if tmp_path and os.path.exists(tmp_path): try: os.remove(tmp_path) except OSError: pass