feat(conversation): TTS 投递与 WebSocket 管线;客户端播放门禁与会话页联动;COS 键与迁移脚本调整
This commit is contained in:
@@ -16,6 +16,7 @@ class MessageType(str, Enum):
|
||||
TRANSCRIPT = "transcript"
|
||||
AGENT_RESPONSE = "agent_response"
|
||||
TTS_AUDIO = "tts_audio"
|
||||
TTS_CANCEL = "tts_cancel"
|
||||
END_CONVERSATION = "end_conversation"
|
||||
MEMOIR_UPDATE = "memoir_update"
|
||||
ERROR = "error"
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
@@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.agents.chat import ChatOrchestrator
|
||||
from app.core.agent_logging import agent_summary_enabled
|
||||
from app.core.config import settings
|
||||
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
|
||||
from app.features.conversation.history_store import ConversationHistoryStore
|
||||
@@ -35,6 +36,17 @@ from app.features.user.models import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
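# Incremented each time the client sends tts_cancel; the TTS loop in process_user_message, and the checks before and after synthesis, compare against it to short-circuit the remaining chunks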
|
||||
_tts_cancel_epoch: dict[str, int] = {}
|
||||
|
||||
|
||||
def bump_tts_cancel_epoch(conversation_id: str) -> None:
    """Advance the TTS cancel epoch stored for ``conversation_id`` by one.

    A conversation with no stored epoch counts as being at 0, so its first
    bump stores 1.
    """
    previous_epoch = _tts_cancel_epoch.get(conversation_id, 0)
    _tts_cancel_epoch[conversation_id] = previous_epoch + 1
|
||||
|
||||
|
||||
def _tts_epoch_value(conversation_id: str) -> int:
    """Return the TTS cancel epoch stored for ``conversation_id``, or 0 if none is stored."""
    try:
        return _tts_cancel_epoch[conversation_id]
    except KeyError:
        # Nothing stored for this conversation: report the baseline epoch.
        return 0
|
||||
|
||||
|
||||
def _tts_object_ext(codec: str) -> str:
|
||||
c = (codec or "mp3").lower().lstrip(".")
|
||||
@@ -58,10 +70,14 @@ async def _send_tts_audio(
|
||||
*,
|
||||
chunk_index: int,
|
||||
chunk_total: int,
|
||||
assistant_message_id: str | None,
|
||||
tts_epoch_start: int,
|
||||
) -> str | None:
|
||||
"""Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None."""
|
||||
if not settings.enable_tts:
|
||||
return None
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
return None
|
||||
try:
|
||||
tts = get_tts_provider()
|
||||
audio_bytes = await tts.synthesize(text)
|
||||
@@ -70,23 +86,30 @@ async def _send_tts_audio(
|
||||
"TTS skipped: synthesize returned empty. Check TTS config in .env"
|
||||
)
|
||||
return None
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
return None
|
||||
ext = _tts_object_ext(settings.tts_codec)
|
||||
content_type = _tts_codec_to_content_type(settings.tts_codec)
|
||||
storage = get_object_storage()
|
||||
key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
|
||||
public_url = storage.upload(key, audio_bytes, content_type)
|
||||
# 与 `tts_delivery.apply_presigned_tts_urls_to_messages` / 回忆录图片 presign 一致:下发可播 URL
|
||||
playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
|
||||
payload_data: Dict[str, Any] = {
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
|
||||
"format": settings.tts_codec,
|
||||
"audio_url": playback_url,
|
||||
"index": chunk_index,
|
||||
"total": chunk_total,
|
||||
}
|
||||
if assistant_message_id:
|
||||
payload_data["assistant_message_id"] = assistant_message_id
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.TTS_AUDIO,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
|
||||
"format": settings.tts_codec,
|
||||
"audio_url": public_url,
|
||||
"index": chunk_index,
|
||||
"total": chunk_total,
|
||||
},
|
||||
"data": payload_data,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
@@ -565,7 +588,7 @@ async def process_user_message(
|
||||
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
_mark_conversation_active(conversation)
|
||||
await store.record_human_ai_turn(
|
||||
ai_msg_id = await store.record_human_ai_turn(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
responses=responses,
|
||||
@@ -576,7 +599,10 @@ async def process_user_message(
|
||||
tts_audio_urls=None,
|
||||
segment_id=segment.id,
|
||||
)
|
||||
if not ai_msg_id:
|
||||
return
|
||||
|
||||
tts_epoch_start = _tts_epoch_value(conversation_id)
|
||||
n = len(responses)
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(
|
||||
@@ -594,14 +620,20 @@ async def process_user_message(
|
||||
)
|
||||
url = None
|
||||
if not skip_tts:
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
break
|
||||
url = await _send_tts_audio(
|
||||
conversation_id,
|
||||
response_text,
|
||||
chunk_index=i,
|
||||
chunk_total=n,
|
||||
assistant_message_id=ai_msg_id,
|
||||
tts_epoch_start=tts_epoch_start,
|
||||
)
|
||||
if url:
|
||||
tts_urls.append(url)
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
break
|
||||
if i < n - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.features.conversation.ws.pipeline import (
|
||||
_mark_conversation_active,
|
||||
_voice_session_id_from_client_segment_id,
|
||||
background_runner,
|
||||
bump_tts_cancel_epoch,
|
||||
chat_orchestrator,
|
||||
cleanup_segment_states,
|
||||
get_or_create_segment_state,
|
||||
@@ -604,6 +605,9 @@ async def websocket_endpoint(
|
||||
},
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.TTS_CANCEL:
|
||||
bump_tts_cancel_epoch(conversation_id)
|
||||
|
||||
elif msg_type == MessageType.END_CONVERSATION:
|
||||
conversation.status = "ended"
|
||||
conversation.ended_at = datetime.now(timezone.utc)
|
||||
|
||||
Reference in New Issue
Block a user