diff --git a/api/.env.production b/api/.env.production index 246982e..5222e71 100644 --- a/api/.env.production +++ b/api/.env.production @@ -46,4 +46,14 @@ TENCENT_SMS_TEMPLATE_ID=2592163 # 短信模板参数数量(1=仅验证码,2=验证码+过期时间) # 如果遇到 TemplateParamSetNotMatchApprovedTemplate 错误,请检查腾讯云控制台中的模板配置 # 并根据实际模板参数数量设置此值 -TENCENT_SMS_TEMPLATE_PARAM_COUNT=1 \ No newline at end of file +TENCENT_SMS_TEMPLATE_PARAM_COUNT=1 + +# CPU 环境(推荐 small + int8) +ASR_MODEL_SIZE=small +ASR_DEVICE=cpu +ASR_COMPUTE_TYPE=int8 + +# GPU 环境(推荐 medium + float16) +# ASR_MODEL_SIZE=medium +# ASR_DEVICE=cuda +# ASR_COMPUTE_TYPE=float16 \ No newline at end of file diff --git a/api/main.py b/api/main.py index 50ad957..efee465 100644 --- a/api/main.py +++ b/api/main.py @@ -94,6 +94,7 @@ app = FastAPI(title="Life Echo API", version="1.0.0") @app.on_event("startup") async def startup_event(): """应用启动事件""" + import asyncio logger.info("=" * 50) logger.info("Life Echo API 正在启动...") logger.info("=" * 50) @@ -105,6 +106,17 @@ async def startup_event(): logger.info("Redis 连接已建立") except Exception as e: logger.warning(f"Redis 连接失败(会话存储将不可用): {e}") + + # 检查并预加载 ASR 模型(在后台线程执行,避免阻塞启动) + try: + from services.asr_service import asr_service + asr_ready = await asyncio.to_thread(asr_service.ensure_ready) + if asr_ready: + logger.info("ASR 模型已就绪(本地 Whisper)") + else: + logger.warning("ASR 模型未就绪,语音转写将不可用") + except Exception as e: + logger.warning(f"ASR 初始化检查失败: {e}") @app.on_event("shutdown") diff --git a/api/requirements.txt b/api/requirements.txt index 6d62d1d..8029e2d 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -38,9 +38,8 @@ httpx==0.27.0 python-jose[cryptography]==3.3.0 bcrypt>=4.0.0 -# Audio Processing (optional, for future ASR/TTS integration) -# pydub==0.25.1 -# speech-recognition==3.10.4 +# Audio Processing - Local Whisper ASR +faster-whisper>=1.0.0 # Image Processing Pillow>=10.0.0 diff --git a/api/routers/websocket.py b/api/routers/websocket.py index c0ddfbf..a726ebe 100644 --- 
a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -19,6 +19,7 @@ from database.models import Conversation, Segment from database.models import User as UserModel from services.auth_service import verify_token from services.memoir_state_service import get_or_create_state +from services.asr_service import asr_service from fastapi import HTTPException, status logger = logging.getLogger(__name__) @@ -28,8 +29,9 @@ class MessageType(str, Enum): """WebSocket 消息类型""" CONNECT = "connect" AUDIO_CHUNK = "audio_chunk" + AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音) TEXT = "text" # 文本消息 - TRANSCRIPT = "transcript" + TRANSCRIPT = "transcript" # 语音转文字结果 AGENT_RESPONSE = "agent_response" TTS_AUDIO = "tts_audio" END_CONVERSATION = "end_conversation" @@ -190,6 +192,70 @@ async def websocket_endpoint( manager=manager ) + elif msg_type == MessageType.AUDIO_MESSAGE: + # 处理完整音频消息(类似微信语音) + data = message.get("data", {}) + audio_base64 = data.get("audio_base64", "") + audio_duration = data.get("duration", 0) + + if audio_base64: + logger.info(f"收到音频消息,时长: {audio_duration}s") + + try: + # 1. ASR 转写 + transcript_text = await asr_service.transcribe(audio_base64) + logger.info(f"ASR 转写结果: {transcript_text}") + + # 2. 发送转写结果给客户端 + await manager.send_message(conversation_id, { + "type": MessageType.TRANSCRIPT, + "conversation_id": conversation_id, + "data": { + "text": transcript_text, + "audio_duration": audio_duration + }, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + # 3. 保存段落到数据库(包含转写文本和音频信息) + segment = Segment( + id=str(uuid.uuid4()), + conversation_id=conversation_id, + transcript_text=transcript_text, + audio_url=f"audio:{audio_duration}s", # 简化存储,标记为音频消息 + processed=False + ) + db.add(segment) + await db.commit() + await db.refresh(segment) + await manager.background_runner.queue_message(conversation.user_id, segment.id) + + # 4. 
"""
ASR service: speech-to-text.

Uses a local faster-whisper model for speech recognition.
"""
import asyncio
import base64
import logging
import os
import tempfile
import threading
from typing import Optional

logger = logging.getLogger(__name__)

# Model configuration (read once at import time from the environment).
# Available model sizes: tiny, base, small, medium, large-v2, large-v3.
# tiny/base suit CPU; small/medium need more resources; large needs a GPU.
ASR_MODEL_SIZE = os.getenv("ASR_MODEL_SIZE", "small")
ASR_DEVICE = os.getenv("ASR_DEVICE", "auto")  # auto, cpu, cuda
ASR_COMPUTE_TYPE = os.getenv("ASR_COMPUTE_TYPE", "auto")  # auto, int8, float16, float32


class ASRService:
    """
    ASR service (speech-to-text) backed by a local faster-whisper model.

    The model is loaded lazily on first use, or eagerly via ``ensure_ready()``.
    Failures are reported in-band: ``transcribe`` returns a string starting
    with "转写失败" instead of raising, and callers test for that prefix.
    """

    def __init__(self):
        self.model = None
        self._model_loaded = False
        self._load_error = None
        # Serializes model loading: transcription runs on worker threads, so
        # two first requests could otherwise both load the model.
        self._load_lock = threading.Lock()

    def _load_model(self) -> bool:
        """
        Load the Whisper model if it has not been loaded yet.

        A successful load is remembered and later calls return immediately.
        A failed load is NOT remembered, so the next call retries; the
        failure reason is kept in ``self._load_error``.

        Returns:
            True if the model is available, False otherwise.
        """
        if self._model_loaded:
            return self.model is not None

        with self._load_lock:
            # Re-check: another thread may have finished loading while this
            # one was waiting for the lock.
            if self._model_loaded:
                return self.model is not None

            try:
                from faster_whisper import WhisperModel

                logger.info(f"正在加载 Whisper 模型: {ASR_MODEL_SIZE}, device={ASR_DEVICE}, compute_type={ASR_COMPUTE_TYPE}")

                # Resolve the effective device and compute type.
                device = ASR_DEVICE
                compute_type = ASR_COMPUTE_TYPE

                if device == "auto":
                    # Prefer CUDA when torch is installed and reports a GPU;
                    # otherwise fall back to CPU.
                    try:
                        import torch
                        device = "cuda" if torch.cuda.is_available() else "cpu"
                    except ImportError:
                        device = "cpu"

                if compute_type == "auto":
                    # float16 on GPU; int8 quantization on CPU for speed.
                    compute_type = "float16" if device == "cuda" else "int8"

                self.model = WhisperModel(
                    ASR_MODEL_SIZE,
                    device=device,
                    compute_type=compute_type
                )

                self._model_loaded = True
                self._load_error = None  # drop the reason of any earlier failed attempt
                logger.info(f"Whisper 模型加载成功: {ASR_MODEL_SIZE} on {device} ({compute_type})")
                return True

            except ImportError:
                self._load_error = "faster-whisper 未安装,请运行: pip install faster-whisper"
                logger.error(self._load_error)
                return False
            except Exception as e:
                self._load_error = f"加载 Whisper 模型失败: {str(e)}"
                logger.error(self._load_error, exc_info=True)
                return False

    def ensure_ready(self) -> bool:
        """
        Make sure the ASR model is loaded (for preloading/checking at startup).

        This is a synchronous, blocking call; run it in a background thread
        when calling from async code.

        Returns:
            True if the model is ready, False otherwise.
        """
        return self._load_model()

    def is_ready(self) -> bool:
        """Return whether the ASR model has been loaded and is usable."""
        return self._model_loaded and self.model is not None

    async def transcribe(self, audio_base64: str) -> Optional[str]:
        """
        Transcribe Base64-encoded audio to text.

        Model loading and decoding are blocking, CPU-heavy operations, so
        they run in a worker thread to keep the event loop responsive.

        Args:
            audio_base64: Base64-encoded audio data.

        Returns:
            The transcript with surrounding whitespace stripped (possibly an
            empty string). On failure, a message starting with "转写失败" is
            returned instead of raising.
        """
        if not audio_base64:
            logger.warning("ASR 转写失败: 音频数据为空")
            return "转写失败: 音频数据为空"

        try:
            return await asyncio.to_thread(self._transcribe_sync, audio_base64)
        except Exception as e:
            # _transcribe_sync reports its own errors in-band; this only
            # guards against failures in the thread hand-off itself.
            logger.error(f"ASR 转写失败: {e}", exc_info=True)
            return f"转写失败: {str(e)}"

    def _transcribe_sync(self, audio_base64: str) -> str:
        """Blocking implementation of ``transcribe``; call from a worker thread."""
        # Lazy-load the model on first use.
        self._load_model()
        if self.model is None:
            error_msg = self._load_error or "ASR 模型未加载"
            logger.warning(error_msg)
            return f"转写失败: {error_msg}"

        tmp_file_path = None
        try:
            # Decode the Base64 payload.
            audio_bytes = base64.b64decode(audio_base64)

            # faster-whisper is given a file path, so spill the bytes to a
            # temp file. Record the path BEFORE writing so a failed write
            # still gets cleaned up in ``finally``.
            with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp_file:
                tmp_file_path = tmp_file.name
                tmp_file.write(audio_bytes)

            # language="zh" pins Chinese, skipping language detection.
            # beam_size=5 is the library default.
            segments, info = self.model.transcribe(
                tmp_file_path,
                language="zh",
                beam_size=5,
                vad_filter=True,  # VAD filtering drops silent parts
                vad_parameters=dict(
                    min_silence_duration_ms=500,  # minimum silence duration
                )
            )

            # ``segments`` is lazy: decoding actually happens while it is
            # consumed here, which is why this must stay off the event loop.
            transcript_text = "".join(segment.text for segment in segments)

            logger.info(f"ASR 转写完成: 语言={info.language}, 概率={info.language_probability:.2f}, 文本长度={len(transcript_text)}")

            return transcript_text.strip()

        except Exception as e:
            logger.error(f"ASR 转写失败: {e}", exc_info=True)
            return f"转写失败: {str(e)}"
        finally:
            # Best-effort cleanup of the temp file.
            if tmp_file_path:
                try:
                    os.remove(tmp_file_path)
                except FileNotFoundError:
                    pass
                except OSError as e:
                    logger.warning(f"清理临时文件失败: {tmp_file_path}: {e}")


# Global instance