Files
life-echo/api/services/asr_service.py
iammm0 0b7bd37d5d feat: Docker 构建预置 ASR 模型,支持离线使用
- Dockerfile 构建时预下载 faster-whisper 模型到镜像
- docker-compose 增加 ASR_MODEL_CACHE_DIR 环境变量
- asr_service 支持从缓存目录加载本地模型,无需运行时联网下载

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-10 15:08:00 +08:00

160 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ASR 服务:语音转文字
使用本地 faster-whisper 模型进行语音识别
"""
import asyncio
import base64
import logging
import os
import tempfile
import threading
from typing import Optional
logger = logging.getLogger(__name__)

# Model configuration, each value overridable through an environment variable.
# Selectable model sizes: tiny, base, small, medium, large-v2, large-v3.
# tiny/base suit CPU; small/medium need more resources; large needs a GPU.
ASR_MODEL_SIZE = os.environ.get("ASR_MODEL_SIZE", "small")
# Inference device: auto, cpu, cuda.
ASR_DEVICE = os.environ.get("ASR_DEVICE", "auto")
# Compute precision: auto, int8, float16, float32.
ASR_COMPUTE_TYPE = os.environ.get("ASR_COMPUTE_TYPE", "auto")
# Directory of the model pre-installed in the image. When set, the local model
# is used directly without downloading (keep in sync with the download_root
# used in the Dockerfile). None when the variable is unset.
ASR_MODEL_CACHE_DIR = os.environ.get("ASR_MODEL_CACHE_DIR")
class ASRService:
    """Speech-to-text (ASR) service backed by a local faster-whisper model.

    The model is loaded lazily on first use, or eagerly via ``ensure_ready``.
    Loading is guarded by a lock so concurrent callers load it only once, and
    transcription runs in a worker thread so the event loop is never blocked.
    """

    def __init__(self) -> None:
        self.model = None
        # True once the model has been loaded successfully.
        self._model_loaded = False
        # Human-readable description of the most recent load failure, if any.
        self._load_error = None
        # Serializes model loading: ``ensure_ready`` may run in a background
        # thread while ``transcribe`` triggers a lazy load from a worker thread.
        self._load_lock = threading.Lock()

    @staticmethod
    def _resolve_runtime():
        """Resolve the "auto" device / compute-type settings to concrete values.

        Returns:
            A ``(device, compute_type)`` tuple.
        """
        device = ASR_DEVICE
        compute_type = ASR_COMPUTE_TYPE

        if device == "auto":
            # Prefer CUDA when torch is installed and reports a usable GPU;
            # otherwise fall back to CPU.
            try:
                import torch
                device = "cuda" if torch.cuda.is_available() else "cpu"
            except ImportError:
                device = "cpu"

        if compute_type == "auto":
            # float16 on GPU; int8 quantization on CPU for speed.
            compute_type = "float16" if device == "cuda" else "int8"

        return device, compute_type

    def _load_model(self) -> bool:
        """Load the Whisper model on first call; later calls return immediately.

        A failed attempt is not cached: the next call tries again. The failure
        reason is kept in ``self._load_error``.

        Returns:
            True if the model is loaded and usable, False otherwise.
        """
        if self._model_loaded:
            return self.model is not None

        with self._load_lock:
            # Another thread may have finished loading while we waited.
            if self._model_loaded:
                return self.model is not None

            # Only the import itself is treated as "package missing"; an
            # ImportError raised while constructing the model is reported as
            # a load failure with its real message instead.
            try:
                from faster_whisper import WhisperModel
            except ImportError:
                self._load_error = "faster-whisper 未安装,请运行: pip install faster-whisper"
                logger.error(self._load_error, exc_info=True)
                return False

            try:
                logger.info(
                    "正在加载 Whisper 模型: %s, device=%s, compute_type=%s",
                    ASR_MODEL_SIZE, ASR_DEVICE, ASR_COMPUTE_TYPE,
                )
                device, compute_type = self._resolve_runtime()
                # With a cache directory configured, load strictly from local
                # files so no download is attempted at runtime.
                model = WhisperModel(
                    ASR_MODEL_SIZE,
                    device=device,
                    compute_type=compute_type,
                    download_root=ASR_MODEL_CACHE_DIR or None,
                    local_files_only=bool(ASR_MODEL_CACHE_DIR),
                )
            except Exception as e:
                self._load_error = f"加载 Whisper 模型失败: {e}"
                logger.error(self._load_error, exc_info=True)
                return False

            # Publish the model before flipping the flag so the lock-free
            # fast path above never observes "loaded" with a missing model.
            self.model = model
            self._load_error = None
            self._model_loaded = True
            logger.info(
                "Whisper 模型加载成功: %s on %s (%s)",
                ASR_MODEL_SIZE, device, compute_type,
            )
            return True

    def ensure_ready(self) -> bool:
        """Make sure the ASR model is loaded (for startup preloading/checks).

        This is a synchronous, blocking call; when invoked during application
        initialization it is best run in a background thread.

        Returns:
            True if the model is ready, False otherwise.
        """
        return self._load_model()

    def is_ready(self) -> bool:
        """Return whether the ASR model has been loaded and is usable."""
        return self._model_loaded and self.model is not None

    async def transcribe(self, audio_base64: str) -> Optional[str]:
        """Transcribe Base64-encoded audio to text.

        Model loading and inference are blocking, so they are executed in the
        event loop's default executor rather than on the loop thread.

        Args:
            audio_base64: Base64-encoded audio data.

        Returns:
            The transcript text; on failure, an error message prefixed with
            "转写失败: ".
        """
        if not audio_base64:
            logger.warning("ASR 转写失败: 音频数据为空")
            return "转写失败: 音频数据为空"

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio_base64)

    def _transcribe_sync(self, audio_base64: str) -> str:
        """Blocking implementation of ``transcribe``; runs in a worker thread."""
        # Lazy-load the model.
        self._load_model()
        if self.model is None:
            error_msg = self._load_error or "ASR 模型未加载"
            logger.warning(error_msg)
            return f"转写失败: {error_msg}"

        tmp_file_path = None
        try:
            audio_bytes = base64.b64decode(audio_base64)

            # The model is handed a file path, so spill the audio to disk.
            with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp_file:
                # Record the path before writing so a failed write is still
                # cleaned up in ``finally``.
                tmp_file_path = tmp_file.name
                tmp_file.write(audio_bytes)

            # language="zh" pins Chinese, skipping language detection.
            # beam_size=5 is the library default and may be tuned.
            segments, info = self.model.transcribe(
                tmp_file_path,
                language="zh",
                beam_size=5,
                vad_filter=True,  # Use VAD to drop silent stretches.
                vad_parameters=dict(
                    min_silence_duration_ms=500,  # Minimum silence duration.
                ),
            )

            # Consume the segments here, inside the worker thread, and merge
            # them into a single transcript.
            transcript_text = "".join(segment.text for segment in segments)
            logger.info(
                "ASR 转写完成: 语言=%s, 概率=%.2f, 文本长度=%d",
                info.language, info.language_probability, len(transcript_text),
            )
            return transcript_text.strip()
        except Exception as e:
            logger.exception("ASR 转写失败: %s", e)
            return f"转写失败: {e}"
        finally:
            self._remove_temp_file(tmp_file_path)

    @staticmethod
    def _remove_temp_file(path: Optional[str]) -> None:
        """Best-effort removal of the temporary audio file; never raises."""
        if not path:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("清理临时文件失败: %s (%s)", path, e)
# Module-level service instance, created when this module is imported.
asr_service = ASRService()