feat: 扩展后端WebSocket和语音识别功能
- 扩展 websocket.py 支持语音消息
- 优化 asr_service.py 语音识别服务
- 更新 main.py 和 requirements.txt
- 更新 .env.production 配置

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -47,3 +47,13 @@ TENCENT_SMS_TEMPLATE_ID=2592163
|
|||||||
# 如果遇到 TemplateParamSetNotMatchApprovedTemplate 错误,请检查腾讯云控制台中的模板配置
|
# 如果遇到 TemplateParamSetNotMatchApprovedTemplate 错误,请检查腾讯云控制台中的模板配置
|
||||||
# 并根据实际模板参数数量设置此值
|
# 并根据实际模板参数数量设置此值
|
||||||
TENCENT_SMS_TEMPLATE_PARAM_COUNT=1
|
TENCENT_SMS_TEMPLATE_PARAM_COUNT=1
|
||||||
|
|
||||||
|
# CPU 环境(推荐 small + int8)
|
||||||
|
ASR_MODEL_SIZE=small
|
||||||
|
ASR_DEVICE=cpu
|
||||||
|
ASR_COMPUTE_TYPE=int8
|
||||||
|
|
||||||
|
# GPU 环境(推荐 medium + float16)
|
||||||
|
# ASR_MODEL_SIZE=medium
|
||||||
|
# ASR_DEVICE=cuda
|
||||||
|
# ASR_COMPUTE_TYPE=float16
|
||||||
12
api/main.py
12
api/main.py
@@ -94,6 +94,7 @@ app = FastAPI(title="Life Echo API", version="1.0.0")
|
|||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
"""应用启动事件"""
|
"""应用启动事件"""
|
||||||
|
import asyncio
|
||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
logger.info("Life Echo API 正在启动...")
|
logger.info("Life Echo API 正在启动...")
|
||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
@@ -106,6 +107,17 @@ async def startup_event():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Redis 连接失败(会话存储将不可用): {e}")
|
logger.warning(f"Redis 连接失败(会话存储将不可用): {e}")
|
||||||
|
|
||||||
|
# 检查并预加载 ASR 模型(在工作线程中执行,避免阻塞事件循环;启动流程仍会等待加载完成)
|
||||||
|
try:
|
||||||
|
from services.asr_service import asr_service
|
||||||
|
asr_ready = await asyncio.to_thread(asr_service.ensure_ready)
|
||||||
|
if asr_ready:
|
||||||
|
logger.info("ASR 模型已就绪(本地 Whisper)")
|
||||||
|
else:
|
||||||
|
logger.warning("ASR 模型未就绪,语音转写将不可用")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"ASR 初始化检查失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
async def shutdown_event():
|
async def shutdown_event():
|
||||||
|
|||||||
@@ -38,9 +38,8 @@ httpx==0.27.0
|
|||||||
python-jose[cryptography]==3.3.0
|
python-jose[cryptography]==3.3.0
|
||||||
bcrypt>=4.0.0
|
bcrypt>=4.0.0
|
||||||
|
|
||||||
# Audio Processing (optional, for future ASR/TTS integration)
|
# Audio Processing - Local Whisper ASR
|
||||||
# pydub==0.25.1
|
faster-whisper>=1.0.0
|
||||||
# speech-recognition==3.10.4
|
|
||||||
|
|
||||||
# Image Processing
|
# Image Processing
|
||||||
Pillow>=10.0.0
|
Pillow>=10.0.0
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from database.models import Conversation, Segment
|
|||||||
from database.models import User as UserModel
|
from database.models import User as UserModel
|
||||||
from services.auth_service import verify_token
|
from services.auth_service import verify_token
|
||||||
from services.memoir_state_service import get_or_create_state
|
from services.memoir_state_service import get_or_create_state
|
||||||
|
from services.asr_service import asr_service
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -28,8 +29,9 @@ class MessageType(str, Enum):
|
|||||||
"""WebSocket 消息类型"""
|
"""WebSocket 消息类型"""
|
||||||
CONNECT = "connect"
|
CONNECT = "connect"
|
||||||
AUDIO_CHUNK = "audio_chunk"
|
AUDIO_CHUNK = "audio_chunk"
|
||||||
|
AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音)
|
||||||
TEXT = "text" # 文本消息
|
TEXT = "text" # 文本消息
|
||||||
TRANSCRIPT = "transcript"
|
TRANSCRIPT = "transcript" # 语音转文字结果
|
||||||
AGENT_RESPONSE = "agent_response"
|
AGENT_RESPONSE = "agent_response"
|
||||||
TTS_AUDIO = "tts_audio"
|
TTS_AUDIO = "tts_audio"
|
||||||
END_CONVERSATION = "end_conversation"
|
END_CONVERSATION = "end_conversation"
|
||||||
@@ -190,6 +192,70 @@ async def websocket_endpoint(
|
|||||||
manager=manager
|
manager=manager
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif msg_type == MessageType.AUDIO_MESSAGE:
|
||||||
|
# 处理完整音频消息(类似微信语音)
|
||||||
|
data = message.get("data", {})
|
||||||
|
audio_base64 = data.get("audio_base64", "")
|
||||||
|
audio_duration = data.get("duration", 0)
|
||||||
|
|
||||||
|
if audio_base64:
|
||||||
|
logger.info(f"收到音频消息,时长: {audio_duration}s")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. ASR 转写
|
||||||
|
transcript_text = await asr_service.transcribe(audio_base64)
|
||||||
|
logger.info(f"ASR 转写结果: {transcript_text}")
|
||||||
|
|
||||||
|
# 2. 发送转写结果给客户端
|
||||||
|
await manager.send_message(conversation_id, {
|
||||||
|
"type": MessageType.TRANSCRIPT,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"data": {
|
||||||
|
"text": transcript_text,
|
||||||
|
"audio_duration": audio_duration
|
||||||
|
},
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
# 3. 保存段落到数据库(包含转写文本和音频信息)
|
||||||
|
segment = Segment(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
transcript_text=transcript_text,
|
||||||
|
audio_url=f"audio:{audio_duration}s", # 简化存储,标记为音频消息
|
||||||
|
processed=False
|
||||||
|
)
|
||||||
|
db.add(segment)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(segment)
|
||||||
|
await manager.background_runner.queue_message(conversation.user_id, segment.id)
|
||||||
|
|
||||||
|
# 4. Agent 生成回应(基于转写文本)
|
||||||
|
if transcript_text and not transcript_text.startswith("转写失败"):
|
||||||
|
await process_user_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
user_message=transcript_text,
|
||||||
|
conversation=conversation,
|
||||||
|
segment=segment,
|
||||||
|
db=db,
|
||||||
|
manager=manager
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 转写失败,发送错误消息
|
||||||
|
await manager.send_message(conversation_id, {
|
||||||
|
"type": MessageType.ERROR,
|
||||||
|
"data": {"message": "语音转写失败,请重试或使用文字输入"},
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理音频消息失败: {e}", exc_info=True)
|
||||||
|
await manager.send_message(conversation_id, {
|
||||||
|
"type": MessageType.ERROR,
|
||||||
|
"data": {"message": f"处理音频消息失败: {str(e)}"},
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
elif msg_type == MessageType.END_CONVERSATION:
|
elif msg_type == MessageType.END_CONVERSATION:
|
||||||
# 结束对话
|
# 结束对话
|
||||||
conversation.status = "ended"
|
conversation.status = "ended"
|
||||||
|
|||||||
@@ -1,23 +1,95 @@
|
|||||||
"""
|
"""
|
||||||
ASR 服务:语音转文字
|
ASR 服务:语音转文字
|
||||||
|
使用本地 faster-whisper 模型进行语音识别
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from openai import OpenAI
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
# 可选模型: tiny, base, small, medium, large-v2, large-v3
|
||||||
|
# tiny/base 适合 CPU,small/medium 需要更多资源,large 需要 GPU
|
||||||
|
ASR_MODEL_SIZE = os.getenv("ASR_MODEL_SIZE", "small")
|
||||||
|
ASR_DEVICE = os.getenv("ASR_DEVICE", "auto") # auto, cpu, cuda
|
||||||
|
ASR_COMPUTE_TYPE = os.getenv("ASR_COMPUTE_TYPE", "auto") # auto, int8, float16, float32
|
||||||
|
|
||||||
|
|
||||||
class ASRService:
|
class ASRService:
|
||||||
"""ASR 服务(语音转文字)"""
|
"""
|
||||||
|
ASR 服务(语音转文字)
|
||||||
|
使用 faster-whisper 本地模型
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
api_key = os.getenv("OPENAI_API_KEY", "")
|
self.model = None
|
||||||
if api_key:
|
self._model_loaded = False
|
||||||
self.client = OpenAI(api_key=api_key)
|
self._load_error = None
|
||||||
else:
|
|
||||||
self.client = None
|
|
||||||
|
|
||||||
async def transcribe(self, audio_base64: str) -> str | None:
|
def _load_model(self) -> bool:
|
||||||
|
"""加载模型(首次调用时执行,后续直接返回)。返回是否加载成功。"""
|
||||||
|
if self._model_loaded:
|
||||||
|
return self.model is not None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
logger.info(f"正在加载 Whisper 模型: {ASR_MODEL_SIZE}, device={ASR_DEVICE}, compute_type={ASR_COMPUTE_TYPE}")
|
||||||
|
|
||||||
|
# 确定设备和计算类型
|
||||||
|
device = ASR_DEVICE
|
||||||
|
compute_type = ASR_COMPUTE_TYPE
|
||||||
|
|
||||||
|
if device == "auto":
|
||||||
|
# 自动检测:优先使用 CUDA,否则使用 CPU
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
except ImportError:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
if compute_type == "auto":
|
||||||
|
# 根据设备自动选择计算类型
|
||||||
|
if device == "cuda":
|
||||||
|
compute_type = "float16" # GPU 使用 float16
|
||||||
|
else:
|
||||||
|
compute_type = "int8" # CPU 使用 int8 量化,速度更快
|
||||||
|
|
||||||
|
self.model = WhisperModel(
|
||||||
|
ASR_MODEL_SIZE,
|
||||||
|
device=device,
|
||||||
|
compute_type=compute_type
|
||||||
|
)
|
||||||
|
|
||||||
|
self._model_loaded = True
|
||||||
|
logger.info(f"Whisper 模型加载成功: {ASR_MODEL_SIZE} on {device} ({compute_type})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
self._load_error = "faster-whisper 未安装,请运行: pip install faster-whisper"
|
||||||
|
logger.error(self._load_error)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
self._load_error = f"加载 Whisper 模型失败: {str(e)}"
|
||||||
|
logger.error(self._load_error, exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def ensure_ready(self) -> bool:
|
||||||
|
"""
|
||||||
|
确保 ASR 模型已就绪(用于启动时预加载与检查)。
|
||||||
|
可在应用初始化时调用;为同步阻塞调用,建议在后台线程执行。
|
||||||
|
返回是否就绪。
|
||||||
|
"""
|
||||||
|
return self._load_model()
|
||||||
|
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
"""检查 ASR 模型是否已加载并可用。"""
|
||||||
|
return self._model_loaded and self.model is not None
|
||||||
|
|
||||||
|
async def transcribe(self, audio_base64: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
转写音频为文字
|
转写音频为文字
|
||||||
|
|
||||||
@@ -25,39 +97,56 @@ class ASRService:
|
|||||||
audio_base64: Base64 编码的音频数据
|
audio_base64: Base64 编码的音频数据
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
转写文本
|
转写文本;失败时返回以“转写失败”开头的错误信息字符串(调用方据此前缀判断是否失败)
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
# 懒加载模型
|
||||||
# 如果没有配置 API Key,返回模拟数据
|
self._load_model()
|
||||||
return "这是模拟的转写文本(请配置 OPENAI_API_KEY 以使用实际 ASR 功能)"
|
|
||||||
|
|
||||||
|
if not self.model:
|
||||||
|
error_msg = self._load_error or "ASR 模型未加载"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
return f"转写失败: {error_msg}"
|
||||||
|
|
||||||
|
tmp_file_path = None
|
||||||
try:
|
try:
|
||||||
# 解码 Base64 音频
|
# 解码 Base64 音频
|
||||||
audio_bytes = base64.b64decode(audio_base64)
|
audio_bytes = base64.b64decode(audio_base64)
|
||||||
|
|
||||||
# 保存临时文件
|
# 保存临时文件
|
||||||
import tempfile
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp_file:
|
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp_file:
|
||||||
tmp_file.write(audio_bytes)
|
tmp_file.write(audio_bytes)
|
||||||
tmp_file_path = tmp_file.name
|
tmp_file_path = tmp_file.name
|
||||||
|
|
||||||
try:
|
# 使用 faster-whisper 转写
|
||||||
# 调用 OpenAI Whisper API
|
# language="zh" 指定中文,可以提高识别速度
|
||||||
with open(tmp_file_path, "rb") as audio_file:
|
# beam_size=5 是默认值,可以调整
|
||||||
transcript = self.client.audio.transcriptions.create(
|
segments, info = self.model.transcribe(
|
||||||
model="whisper-1",
|
tmp_file_path,
|
||||||
file=audio_file,
|
language="zh",
|
||||||
language="zh" # 中文
|
beam_size=5,
|
||||||
|
vad_filter=True, # 启用 VAD 过滤静音部分
|
||||||
|
vad_parameters=dict(
|
||||||
|
min_silence_duration_ms=500, # 最小静音时长
|
||||||
)
|
)
|
||||||
return transcript.text
|
)
|
||||||
|
|
||||||
|
# 合并所有转写片段
|
||||||
|
transcript_text = "".join(segment.text for segment in segments)
|
||||||
|
|
||||||
|
logger.info(f"ASR 转写完成: 语言={info.language}, 概率={info.language_probability:.2f}, 文本长度={len(transcript_text)}")
|
||||||
|
|
||||||
|
return transcript_text.strip() if transcript_text else ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ASR 转写失败: {e}", exc_info=True)
|
||||||
|
return f"转写失败: {str(e)}"
|
||||||
finally:
|
finally:
|
||||||
# 清理临时文件
|
# 清理临时文件
|
||||||
import os
|
if tmp_file_path and os.path.exists(tmp_file_path):
|
||||||
if os.path.exists(tmp_file_path):
|
try:
|
||||||
os.remove(tmp_file_path)
|
os.remove(tmp_file_path)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# 出错时返回错误信息
|
pass
|
||||||
return f"转写失败: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
# 全局实例
|
||||||
|
|||||||
Reference in New Issue
Block a user