feat(conversation): TTS 投递与 WebSocket 管线;客户端播放门禁与会话页联动;COS 键与迁移脚本调整

This commit is contained in:
Kevin
2026-03-26 15:51:24 +08:00
parent c23931ec91
commit d990399112
22 changed files with 630 additions and 74 deletions

View File

@@ -16,6 +16,7 @@ class MessageType(str, Enum):
TRANSCRIPT = "transcript"
AGENT_RESPONSE = "agent_response"
TTS_AUDIO = "tts_audio"
TTS_CANCEL = "tts_cancel"
END_CONVERSATION = "end_conversation"
MEMOIR_UPDATE = "memoir_update"
ERROR = "error"

View File

@@ -6,7 +6,7 @@ import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from app.core.logging import get_logger
@@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat import ChatOrchestrator
from app.core.agent_logging import agent_summary_enabled
from app.core.config import settings
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
from app.features.conversation.history_store import ConversationHistoryStore
@@ -35,6 +36,17 @@ from app.features.user.models import User
logger = get_logger(__name__)
# 客户端发送 tts_cancel 时递增process_user_message 内 TTS 循环与合成前后对照,用于短路剩余片段
_tts_cancel_epoch: dict[str, int] = {}
def bump_tts_cancel_epoch(conversation_id: str) -> None:
_tts_cancel_epoch[conversation_id] = _tts_cancel_epoch.get(conversation_id, 0) + 1
def _tts_epoch_value(conversation_id: str) -> int:
return _tts_cancel_epoch.get(conversation_id, 0)
def _tts_object_ext(codec: str) -> str:
c = (codec or "mp3").lower().lstrip(".")
@@ -58,10 +70,14 @@ async def _send_tts_audio(
*,
chunk_index: int,
chunk_total: int,
assistant_message_id: str | None,
tts_epoch_start: int,
) -> str | None:
"""Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None."""
if not settings.enable_tts:
return None
if _tts_epoch_value(conversation_id) != tts_epoch_start:
return None
try:
tts = get_tts_provider()
audio_bytes = await tts.synthesize(text)
@@ -70,23 +86,30 @@ async def _send_tts_audio(
"TTS skipped: synthesize returned empty. Check TTS config in .env"
)
return None
if _tts_epoch_value(conversation_id) != tts_epoch_start:
return None
ext = _tts_object_ext(settings.tts_codec)
content_type = _tts_codec_to_content_type(settings.tts_codec)
storage = get_object_storage()
key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
public_url = storage.upload(key, audio_bytes, content_type)
# 与 `tts_delivery.apply_presigned_tts_urls_to_messages` / 回忆录图片 presign 一致:下发可播 URL
playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
payload_data: Dict[str, Any] = {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
"audio_url": playback_url,
"index": chunk_index,
"total": chunk_total,
}
if assistant_message_id:
payload_data["assistant_message_id"] = assistant_message_id
await manager.send_message(
conversation_id,
{
"type": MessageType.TTS_AUDIO,
"conversation_id": conversation_id,
"data": {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
"audio_url": public_url,
"index": chunk_index,
"total": chunk_total,
},
"data": payload_data,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
@@ -565,7 +588,7 @@ async def process_user_message(
segment.agent_response = "\n\n".join(responses)
_mark_conversation_active(conversation)
await store.record_human_ai_turn(
ai_msg_id = await store.record_human_ai_turn(
conversation_id=conversation_id,
user_message=user_message,
responses=responses,
@@ -576,7 +599,10 @@ async def process_user_message(
tts_audio_urls=None,
segment_id=segment.id,
)
if not ai_msg_id:
return
tts_epoch_start = _tts_epoch_value(conversation_id)
n = len(responses)
for i, response_text in enumerate(responses):
await manager.send_message(
@@ -594,14 +620,20 @@ async def process_user_message(
)
url = None
if not skip_tts:
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
url = await _send_tts_audio(
conversation_id,
response_text,
chunk_index=i,
chunk_total=n,
assistant_message_id=ai_msg_id,
tts_epoch_start=tts_epoch_start,
)
if url:
tts_urls.append(url)
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
if i < n - 1:
await asyncio.sleep(0.5)

View File

@@ -26,6 +26,7 @@ from app.features.conversation.ws.pipeline import (
_mark_conversation_active,
_voice_session_id_from_client_segment_id,
background_runner,
bump_tts_cancel_epoch,
chat_orchestrator,
cleanup_segment_states,
get_or_create_segment_state,
@@ -604,6 +605,9 @@ async def websocket_endpoint(
},
)
elif msg_type == MessageType.TTS_CANCEL:
bump_tts_cancel_epoch(conversation_id)
elif msg_type == MessageType.END_CONVERSATION:
conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc)