Files
life-echo/api/app/adapters/asr/whisper_local.py

92 lines
2.8 KiB
Python

"""Local faster-whisper ASR adapter — implements ASRProvider port."""
import asyncio
import os
import tempfile
import threading

from app.core.logging import get_logger
# Module-level logger from the project's logging helper; used below to report
# model-load and transcription failures instead of raising them.
logger = get_logger(__name__)
# Fallback weights directory used when no explicit cache_dir is configured:
# a ``models/whisper`` folder three directory levels above this module,
# normalised into a clean absolute path.
_DEFAULT_CACHE_DIR = os.path.normpath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        *([os.pardir] * 3),
        "models",
        "whisper",
    )
)
class WhisperASRProvider:
    """ASRProvider implementation backed by a locally loaded faster-whisper model.

    The model is created lazily on first use (or eagerly via ``ensure_ready``)
    and cached on the instance. Errors are logged rather than raised: loading
    reports ``False`` and transcription returns ``""``, so callers should treat
    an empty transcript as "no result".

    Args:
        model_size: Model name/size passed to ``faster_whisper.WhisperModel``.
        device: ``"auto"`` picks ``"cuda"`` when ``torch`` is importable and
            reports CUDA available, otherwise ``"cpu"``; any other value is
            passed through unchanged.
        compute_type: ``"auto"`` picks ``"float16"`` on CUDA and ``"int8"``
            otherwise; any other value is passed through unchanged.
        cache_dir: Directory for model weights. When non-empty the model is
            loaded with ``local_files_only=True`` (no download); when empty the
            module's default directory is used and downloading is allowed.
        language: Language code passed to ``WhisperModel.transcribe``. Defaults
            to ``"zh"``, the value this adapter previously hard-coded.
    """

    def __init__(
        self,
        model_size: str = "small",
        device: str = "auto",
        compute_type: str = "auto",
        cache_dir: str = "",
        language: str = "zh",
    ):
        self._model_size = model_size
        self._device = device
        self._compute_type = compute_type
        self._cache_dir = cache_dir
        self._language = language
        self._model = None
        # Transcription runs in worker threads, so concurrent first calls could
        # otherwise each construct (and hold in memory) their own model.
        self._load_lock = threading.Lock()

    def _load_model(self) -> bool:
        """Load and cache the Whisper model if it is not loaded yet.

        Blocking: may import heavy libraries and read/download model weights.

        Returns:
            True when a model is available; False when loading failed (the
            error is logged, and the next call will try again).
        """
        if self._model is not None:
            return True
        with self._load_lock:
            # Re-check under the lock: another thread may have finished the
            # load while this one was waiting.
            if self._model is not None:
                return True
            try:
                from faster_whisper import WhisperModel

                device = self._device
                if device == "auto":
                    try:
                        import torch  # type: ignore[import-untyped]

                        device = "cuda" if torch.cuda.is_available() else "cpu"
                    except ImportError:
                        device = "cpu"

                compute_type = self._compute_type
                if compute_type == "auto":
                    compute_type = "float16" if device == "cuda" else "int8"

                download_root = self._cache_dir or _DEFAULT_CACHE_DIR
                # An explicitly configured cache dir is treated as already
                # populated (presumably for offline deployments): no download.
                local_files_only = bool(self._cache_dir)
                os.makedirs(download_root, exist_ok=True)

                self._model = WhisperModel(
                    self._model_size,
                    device=device,
                    compute_type=compute_type,
                    download_root=download_root,
                    local_files_only=local_files_only,
                )
                return True
            except Exception as e:
                logger.error("Failed to load Whisper model: %s", e)
                return False

    def ensure_ready(self) -> bool:
        """Eagerly load the model (blocking); return True if it is usable."""
        return self._load_model()

    async def transcribe(self, audio: bytes, format: str = "m4a") -> str:
        """Transcribe an encoded audio clip to text.

        Model loading, the temp-file write and decoding are all blocking, so
        they run in the event loop's default executor instead of on the loop.

        Args:
            audio: Raw encoded audio bytes.
            format: File extension used as the temp-file suffix (e.g. ``"m4a"``).
                A leading dot is tolerated; non-alphanumeric values are dropped.

        Returns:
            The stripped transcript, or ``""`` when the input is empty, the
            model could not be loaded, or transcription failed.
        """
        if not audio:
            return ""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio, format)

    def _transcribe_sync(self, audio: bytes, format: str) -> str:
        """Blocking body of ``transcribe``; intended to run in a worker thread."""
        if not self._load_model():
            return ""

        # ``format`` may come from a client; keep only a plain extension so it
        # cannot smuggle path separators into the temp-file name.
        ext = format.strip().lstrip(".")
        suffix = f".{ext}" if ext.isalnum() else ""

        tmp_path = None
        try:
            # delete=False so the decoder can reopen the file by name after it
            # is closed (required on Windows); removal happens in ``finally``.
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                # Capture the path before writing so a failed write still gets
                # cleaned up.
                tmp_path = tmp.name
                tmp.write(audio)
            segments, _info = self._model.transcribe(
                tmp_path,
                language=self._language,
                beam_size=5,
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            # faster-whisper yields segments lazily, so they must be consumed
            # here, before ``finally`` deletes the file they are read from.
            return "".join(seg.text for seg in segments).strip()
        except Exception as e:
            logger.error("Whisper transcribe failed: %s", e)
            return ""
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except FileNotFoundError:
                    pass
                except OSError as e:
                    logger.warning("Failed to remove temp audio file %s: %s", tmp_path, e)