"""
ASR 服务:语音转文字

使用本地 faster-whisper 模型进行语音识别
"""

import asyncio
import base64
import logging
import os
import tempfile
import threading
from typing import Optional

logger = logging.getLogger(__name__)


# --- Model configuration (every value can be overridden via environment) ---
# Available model sizes: tiny, base, small, medium, large-v2, large-v3.
# tiny/base suit CPU; small/medium need more resources; large needs a GPU.
ASR_MODEL_SIZE = os.environ.get("ASR_MODEL_SIZE", "small")
# Accepted values: auto, cpu, cuda
ASR_DEVICE = os.environ.get("ASR_DEVICE", "auto")
# Accepted values: auto, int8, float16, float32
ASR_COMPUTE_TYPE = os.environ.get("ASR_COMPUTE_TYPE", "auto")

# Model cache directory. When set, the model is loaded from here only, without
# network access (keep in sync with download_root in the Dockerfile).
# When unset, the default local directory below (api/models/whisper) is used.
ASR_MODEL_CACHE_DIR = os.environ.get("ASR_MODEL_CACHE_DIR")

# Default local cache directory, resolved relative to this file as
# ../models/whisper, so every start looks for a local copy first.
_DEFAULT_ASR_CACHE_DIR = os.path.normpath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
        "models",
        "whisper",
    )
)


class WhisperASRService:
    """
    Speech-to-text (ASR) service backed by a local faster-whisper model.

    The model is loaded lazily on the first ``transcribe`` call, or eagerly via
    ``ensure_ready``. Loading is serialized with a lock so that a background
    preload and a concurrent ``transcribe`` call cannot load the model twice.
    """

    def __init__(self):
        self.model = None
        self._model_loaded = False
        # Message describing the most recent load failure; reset on success.
        self._load_error = None
        # Guards model loading: ensure_ready() is meant to run in a background
        # thread while transcribe() may trigger the same lazy load concurrently.
        self._load_lock = threading.Lock()

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Return the concrete device; "auto" picks "cuda" when torch reports CUDA, else "cpu"."""
        if device != "auto":
            return device
        try:
            import torch

            return "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            # torch is optional; without it CUDA cannot be probed, so use CPU.
            return "cpu"

    @staticmethod
    def _resolve_compute_type(compute_type: str, device: str) -> str:
        """Return the concrete compute type; "auto" means float16 on CUDA and int8 (quantized) otherwise."""
        if compute_type != "auto":
            return compute_type
        return "float16" if device == "cuda" else "int8"

    def _load_model(self) -> bool:
        """
        Load the Whisper model if it has not been loaded yet.

        After a successful load, later calls return immediately. A failed load
        is not remembered: the next call retries it, and the failure reason is
        kept in ``self._load_error``.

        Returns:
            True if the model is loaded and usable, False otherwise.
        """
        # Fast path: no locking once the model is in place.
        if self._model_loaded:
            return self.model is not None
        with self._load_lock:
            # Another thread may have finished loading while we waited.
            if self._model_loaded:
                return self.model is not None
            return self._load_model_locked()

    def _load_model_locked(self) -> bool:
        """Construct the WhisperModel; the caller must hold ``self._load_lock``."""
        try:
            # Imported lazily so this module can be imported without faster-whisper.
            from faster_whisper import WhisperModel

            logger.info(f"正在加载 Whisper 模型: {ASR_MODEL_SIZE}, device={ASR_DEVICE}, compute_type={ASR_COMPUTE_TYPE}")

            device = self._resolve_device(ASR_DEVICE)
            compute_type = self._resolve_compute_type(ASR_COMPUTE_TYPE, device)

            # Always load from a local directory first: ASR_MODEL_CACHE_DIR when
            # set, otherwise the default api/models/whisper directory.
            download_root = ASR_MODEL_CACHE_DIR if ASR_MODEL_CACHE_DIR else _DEFAULT_ASR_CACHE_DIR
            # Forbid network downloads only when a cache dir is set explicitly (e.g. Docker).
            local_files_only = bool(ASR_MODEL_CACHE_DIR)
            os.makedirs(download_root, exist_ok=True)
            logger.info(f"Whisper 模型从本地加载: download_root={download_root}, local_files_only={local_files_only}")

            self.model = WhisperModel(
                ASR_MODEL_SIZE,
                device=device,
                compute_type=compute_type,
                download_root=download_root,
                local_files_only=local_files_only,
            )

            self._model_loaded = True
            # Clear any error left over from an earlier failed attempt.
            self._load_error = None
            logger.info(f"Whisper 模型加载成功: {ASR_MODEL_SIZE} on {device} ({compute_type})")
            return True

        except ImportError:
            self._load_error = "faster-whisper 未安装,请运行: pip install faster-whisper"
            logger.error(self._load_error)
            return False
        except Exception as e:
            self._load_error = f"加载 Whisper 模型失败: {str(e)}"
            logger.error(self._load_error, exc_info=True)
            return False

    def ensure_ready(self) -> bool:
        """
        Make sure the ASR model is loaded (for preloading/checking at startup).

        Intended to be called during application initialization. It blocks
        while the model loads, so running it in a background thread is
        recommended.

        Returns:
            True if the model is ready.
        """
        return self._load_model()

    def is_ready(self) -> bool:
        """Return whether the ASR model has been loaded and is usable (never triggers a load)."""
        return self._model_loaded and self.model is not None

    async def transcribe(self, audio_base64: str) -> Optional[str]:
        """
        Transcribe Base64-encoded audio to text.

        Model loading, decoding and inference are blocking, so they run in the
        event loop's default thread-pool executor to keep the loop responsive.

        Args:
            audio_base64: Base64-encoded audio data.

        Returns:
            The stripped transcript ("" when nothing was recognized), or, on
            failure, a message starting with "转写失败: ".
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio_base64)

    def _transcribe_sync(self, audio_base64: str) -> str:
        """Blocking body of ``transcribe``; executed in a worker thread."""
        # Lazy-load the model on first use.
        self._load_model()
        model = self.model
        if model is None:
            error_msg = self._load_error or "ASR 模型未加载"
            logger.warning(error_msg)
            return f"转写失败: {error_msg}"

        tmp_file_path = None
        try:
            audio_bytes = base64.b64decode(audio_base64)

            # faster-whisper is given a file path. delete=False keeps the file
            # after the handle closes so it can be reopened for decoding; it is
            # removed explicitly in ``finally``.
            with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp_file:
                # Record the path before writing so a failed write is still cleaned up.
                tmp_file_path = tmp_file.name
                tmp_file.write(audio_bytes)

            # language="zh" pins Chinese instead of auto-detecting it;
            # beam_size=5 is faster-whisper's default; vad_filter skips
            # non-speech audio.
            segments, info = model.transcribe(
                tmp_file_path,
                language="zh",
                beam_size=5,
                vad_filter=True,
                vad_parameters=dict(min_silence_duration_ms=500),
            )

            # ``segments`` is lazy: inference runs while it is consumed, so it
            # must be fully read before the temp file is deleted below.
            transcript_text = "".join(segment.text for segment in segments)

            logger.info(f"ASR 转写完成: 语言={info.language}, 概率={info.language_probability:.2f}, 文本长度={len(transcript_text)}")

            return transcript_text.strip()

        except Exception as e:
            logger.error(f"ASR 转写失败: {e}", exc_info=True)
            return f"转写失败: {str(e)}"
        finally:
            if tmp_file_path:
                try:
                    os.remove(tmp_file_path)
                except FileNotFoundError:
                    pass
                except OSError as e:
                    logger.warning(f"清理临时音频文件失败: {tmp_file_path}: {e}")


# Module-level shared instance; construction is cheap because the Whisper
# model itself is only loaded on first use (or via ensure_ready()).
asr_service: WhisperASRService = WhisperASRService()