@@ -65,6 +65,17 @@ TENCENT_SECRET_ID=your_tencent_asr_secret_id
|
||||
TENCENT_SECRET_KEY=your_tencent_asr_secret_key
|
||||
# TENCENT_ASR_APP_ID=
|
||||
|
||||
# =============================================================================
|
||||
# TTS Provider
|
||||
# Supported values: openai | tencent (any other value falls back to openai)
|
||||
# =============================================================================
|
||||
TTS_PROVIDER=openai
|
||||
# 仅 TTS_PROVIDER=openai 时需要
|
||||
# OPENAI_API_KEY=your_openai_api_key
|
||||
# 仅 TTS_PROVIDER=tencent 时生效,与 ASR 共用 TENCENT_SECRET_ID / TENCENT_SECRET_KEY
|
||||
TTS_VOICE_TYPE=1001
|
||||
TTS_CODEC=mp3
|
||||
|
||||
# =============================================================================
|
||||
# WeChat Pay
|
||||
# =============================================================================
|
||||
|
||||
144
api/app/adapters/tts/tencent_tts.py
Normal file
144
api/app/adapters/tts/tencent_tts.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Tencent Cloud TTS adapter — implements TTSProvider port."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Translation table: OpenAI-style voice name (lower-case) -> Tencent Cloud
# VoiceType ID.
VOICE_MAP: dict[str, int] = dict(
    alloy=1001,
    echo=1002,
    fable=1003,
    onyx=1004,
    nova=1005,
    shimmer=1006,
)
|
||||
|
||||
# 中文 150 字 / 英文 500 字母,取保守值
|
||||
MAX_CHARS_PER_REQUEST = 150
|
||||
|
||||
|
||||
def _chunk_text(text: str, max_chars: int = MAX_CHARS_PER_REQUEST) -> list[str]:
|
||||
"""Split text into chunks within API limit."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
if len(text) <= max_chars:
|
||||
return [text]
|
||||
|
||||
chunks: list[str] = []
|
||||
# Split by sentence boundaries first
|
||||
pattern = r"[。!?.!?\n]+"
|
||||
parts = re.split(f"({pattern})", text)
|
||||
current = ""
|
||||
for i, p in enumerate(parts):
|
||||
if re.match(pattern, p):
|
||||
current += p
|
||||
if current.strip():
|
||||
chunks.append(current.strip())
|
||||
current = ""
|
||||
else:
|
||||
if len(current) + len(p) <= max_chars:
|
||||
current += p
|
||||
else:
|
||||
if current.strip():
|
||||
chunks.append(current.strip())
|
||||
current = ""
|
||||
# Single part exceeds limit, split by length
|
||||
while p:
|
||||
chunk = p[:max_chars]
|
||||
p = p[max_chars:]
|
||||
chunks.append(chunk)
|
||||
if current.strip():
|
||||
chunks.append(current.strip())
|
||||
return chunks
|
||||
|
||||
|
||||
class TencentTTSProvider:
    """Tencent Cloud text-to-speech adapter (TTSProvider port).

    All failures are handled on a best-effort basis: they are logged and an
    empty ``bytes`` object is returned instead of raising.
    """

    # Default voice name of ``synthesize``. It is treated as "caller made no
    # explicit choice" so that the configured ``voice_type`` is honoured;
    # looking it up in VOICE_MAP would always yield 1001 and make the
    # configured value unreachable for default calls.
    _DEFAULT_VOICE = "alloy"

    def __init__(
        self,
        secret_id: str,
        secret_key: str,
        voice_type: int = 1001,
        codec: str = "mp3",
    ) -> None:
        """Store credentials and synthesis defaults; no network I/O happens here.

        Args:
            secret_id: Tencent Cloud API secret id.
            secret_key: Tencent Cloud API secret key.
            voice_type: Tencent VoiceType ID used when the caller does not
                pick a mapped voice.
            codec: Value sent as the request's ``Codec`` field.
        """
        self._secret_id = secret_id
        self._secret_key = secret_key
        self._voice_type = voice_type
        self._codec = codec
        # SDK client, created lazily on first use by _get_client().
        self._client = None

    def _get_client(self):
        """Return the cached SDK client, creating it on first use.

        Returns:
            The ``TtsClient`` instance, or ``None`` when the SDK cannot be
            imported or the client cannot be constructed (the error is
            logged; a later call will try again).
        """
        if self._client is not None:
            return self._client
        try:
            # Imported lazily so this module can be loaded without the SDK.
            from tencentcloud.common import credential
            from tencentcloud.common.profile.client_profile import ClientProfile
            from tencentcloud.common.profile.http_profile import HttpProfile
            from tencentcloud.tts.v20190823 import tts_client

            cred = credential.Credential(self._secret_id, self._secret_key)
            http_profile = HttpProfile()
            http_profile.endpoint = "tts.tencentcloudapi.com"
            client_profile = ClientProfile()
            client_profile.httpProfile = http_profile
            self._client = tts_client.TtsClient(cred, "", client_profile)
            return self._client
        except Exception as e:
            logger.error("Tencent TTS client init failed: %s", e)
            return None

    def _resolve_voice_type(self, voice: str) -> int:
        """Translate a voice name into a Tencent VoiceType ID.

        An empty name, the default name (``"alloy"``) and names missing from
        ``VOICE_MAP`` all resolve to the configured ``voice_type``.
        """
        name = (voice or "").strip().lower()
        if not name or name == self._DEFAULT_VOICE:
            return self._voice_type
        return VOICE_MAP.get(name, self._voice_type)

    def _synthesize_sync(self, text: str, voice_type: int) -> bytes:
        """Synthesize one chunk with a blocking SDK call.

        Args:
            text: Text of a single request.
            voice_type: Tencent VoiceType ID.

        Returns:
            The decoded audio bytes, or ``b""`` on any failure.
        """
        client = self._get_client()
        if not client:
            return b""

        # Import separately from the request: if the import itself failed
        # inside the request's ``try``, evaluating the
        # ``except TencentCloudSDKException`` clause would raise
        # UnboundLocalError and escape the best-effort handling.
        try:
            from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
                TencentCloudSDKException,
            )
            from tencentcloud.tts.v20190823 import models
        except ImportError as e:
            logger.error("Tencent TTS SDK import failed: %s", e)
            return b""

        try:
            req = models.TextToVoiceRequest()
            req.Text = text
            req.SessionId = uuid.uuid4().hex
            req.VoiceType = voice_type
            req.PrimaryLanguage = 1
            req.SampleRate = 16000
            req.Codec = self._codec

            resp = client.TextToVoice(req)
            if not resp or not resp.Audio:
                return b""
            # The response carries the audio as a base64 string.
            return base64.b64decode(resp.Audio)
        except TencentCloudSDKException as e:
            logger.error("Tencent TTS SDK error: %s", e)
            return b""
        except Exception as e:
            logger.error("Tencent TTS synthesize failed: %s", e)
            return b""

    async def synthesize(self, text: str, voice: str = "alloy") -> bytes:
        """Synthesize *text* and return the audio as one ``bytes`` object.

        The text is split into chunks that are synthesized one after another
        in a worker thread, so the blocking SDK call does not stall the event
        loop.

        Args:
            text: Text to speak.
            voice: OpenAI-style voice name; see ``_resolve_voice_type``.

        Returns:
            The concatenated audio of all chunks, or ``b""`` when credentials
            are missing, there is nothing to speak, or any chunk fails.
        """
        if not self._secret_id or not self._secret_key:
            logger.error("Tencent TTS credentials not configured")
            return b""

        chunks = _chunk_text(text)
        if not chunks:
            return b""

        voice_type = self._resolve_voice_type(voice)
        segments: list[bytes] = []
        for chunk in chunks:
            audio = await asyncio.to_thread(
                self._synthesize_sync, chunk, voice_type
            )
            if not audio:
                # All-or-nothing: audio with a missing segment is not returned.
                return b""
            segments.append(audio)

        # NOTE(review): segments are joined byte-wise; presumably this only
        # yields a valid stream for frame-based codecs such as mp3 - confirm
        # before configuring a container format like wav.
        return b"".join(segments)
|
||||
@@ -62,6 +62,11 @@ class Settings(BaseSettings):
|
||||
# ── OpenAI (TTS) ─────────────────────────────────────────
|
||||
openai_api_key: str = ""
|
||||
|
||||
# ── TTS ─────────────────────────────────────────────────
|
||||
tts_provider: str = "openai"
|
||||
tts_voice_type: int = 1001
|
||||
tts_codec: str = "mp3"
|
||||
|
||||
# ── WeChat Pay ───────────────────────────────────────────
|
||||
wechat_pay_app_id: str = ""
|
||||
wechat_pay_mch_id: str = ""
|
||||
|
||||
@@ -60,6 +60,15 @@ def get_llm_provider() -> LLMProvider:
|
||||
|
||||
@lru_cache
def get_tts_provider() -> TTSProvider:
    """Build the TTS adapter selected by ``settings.tts_provider``.

    The result is memoized by ``lru_cache``, so the adapter is constructed on
    the first call and the same instance is returned afterwards.

    The provider name is compared case-insensitively and with surrounding
    whitespace ignored, so a value such as ``Tencent`` is not silently routed
    to the fallback. Any value other than ``tencent`` selects the OpenAI
    adapter.
    """
    provider = (settings.tts_provider or "").strip().lower()

    if provider == "tencent":
        # Imported lazily so only the selected adapter's module is loaded.
        from app.adapters.tts.tencent_tts import TencentTTSProvider

        return TencentTTSProvider(
            secret_id=settings.tencent_secret_id,
            secret_key=settings.tencent_secret_key,
            voice_type=settings.tts_voice_type,
            codec=settings.tts_codec,
        )

    from app.adapters.tts.openai_tts import OpenAITTSProvider

    return OpenAITTSProvider(api_key=settings.openai_api_key)
|
||||
|
||||
@@ -26,11 +26,31 @@ from app.features.conversation.ws.profile_collector import (
|
||||
get_missing_profile_fields,
|
||||
)
|
||||
from app.features.user.models import User
|
||||
from app.core.dependencies import get_asr_provider
|
||||
from app.core.config import settings
|
||||
from app.core.dependencies import get_asr_provider, get_tts_provider
|
||||
from app.features.memoir.state_service import get_or_create_state
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def _send_tts_audio(conversation_id: str, text: str) -> None:
    """Synthesize *text* and push it to the client as a TTS_AUDIO message.

    Best effort: nothing is sent when the text is blank or the provider
    returns no audio, and any exception is logged instead of propagated so a
    TTS failure cannot break the surrounding message flow.

    Args:
        conversation_id: Conversation whose connection receives the audio.
        text: Text to speak.
    """
    # Blank text has nothing to speak; skip the provider round-trip.
    if not text or not text.strip():
        return

    try:
        audio_bytes = await get_tts_provider().synthesize(text)
        if not audio_bytes:
            return

        payload = {
            "type": MessageType.TTS_AUDIO,
            "conversation_id": conversation_id,
            "data": {
                "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
                # NOTE(review): the label comes from settings.tts_codec for
                # every provider - confirm the OpenAI adapter emits the same
                # codec.
                "format": settings.tts_codec,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await manager.send_message(conversation_id, payload)
    except Exception as e:
        logger.error("TTS synthesize failed: %s", e)
|
||||
|
||||
# ── Agent 实例(从 ConnectionManager 移出) ─────────────────────
|
||||
conversation_agent = ConversationAgent()
|
||||
memory_agent = MemoryAgent()
|
||||
@@ -447,6 +467,7 @@ async def process_user_message(
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await _send_tts_audio(conversation_id, response_text)
|
||||
if i < len(responses) - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
return
|
||||
@@ -498,6 +519,7 @@ async def process_user_message(
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await _send_tts_audio(conversation_id, response_text)
|
||||
if i < len(responses) - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user