feat(conversation): TTS 投递与 WebSocket 管线;客户端播放门禁与会话页联动;COS 键与迁移脚本调整

This commit is contained in:
Kevin
2026-03-26 15:51:24 +08:00
parent c23931ec91
commit d990399112
22 changed files with 630 additions and 74 deletions

View File

@@ -6,6 +6,13 @@ from typing import Any
from urllib.parse import urlparse
from app.core.config import settings
from app.core.logging import get_logger
from app.ports.storage import ObjectStorage
logger = get_logger(__name__)
# 客户端再读 TTS / 拉取音频:预签名有效期(秒),与移动端会话长度匹配
TTS_PRESIGNED_EXPIRES_SEC = 86_400
def extract_cos_object_key_if_owned(url: str | None) -> str | None:
@@ -75,3 +82,40 @@ def collect_cos_keys_from_tts_url_list(urls: list[str] | None) -> set[str]:
if k:
keys.add(k)
return keys
def presign_tts_urls_for_playback(
urls: list[str],
storage: ObjectStorage | None,
*,
expires: int = TTS_PRESIGNED_EXPIRES_SEC,
) -> list[str]:
"""
将本环境 COS 直链替换为预签名下载 URL私有桶下匿名 GET 会 AccessDenied
目的与回忆录 `normalize_image_assets_for_api` 中对 `get_download_url` 的用法一致。
非本环境 URL 或无法解析 key 时原样返回。
"""
if not storage or not urls:
return list(urls)
out: list[str] = []
for u in urls:
if not isinstance(u, str):
continue
s = u.strip()
if not s:
continue
key = extract_cos_object_key_if_owned(s)
if key:
try:
out.append(storage.get_url(key, expires=expires))
except Exception as exc:
logger.warning(
"presign tts url failed, keeping original url: key={} err={}",
key,
exc,
)
out.append(s)
else:
out.append(s)
return out

View File

@@ -118,7 +118,7 @@ class RedisService:
async def append_tts_audio_url_to_last_ai_message(
self, conversation_id: str, url: str
) -> bool:
"""向最近一条 AI 消息的 ttsAudioUrls 追加 COS 公开 URL。"""
"""向最近一条 AI 消息的 ttsAudioUrls 追加 upload 返回的 canonical URL非预签名。客户端通过 GET /messages 等出口收到预签名 URL。"""
if not url:
return False
try:

View File

@@ -85,9 +85,9 @@ class ConversationHistoryStore:
audio_duration_seconds: int | None,
tts_audio_urls: list[str] | None,
segment_id: str | None,
) -> None:
) -> str | None:
if not responses:
return
return None
human_ts = user_message_timestamp or _utc_now()
if human_ts.tzinfo is None:
human_ts = human_ts.replace(tzinfo=timezone.utc)
@@ -122,6 +122,7 @@ class ConversationHistoryStore:
await self._touch_conversation(conversation_id, occurred_at=ai_ts)
await self._db.commit()
await self._sync_redis_best_effort(conversation_id)
return ai.id
async def attach_ai_tts_audio_urls(
self,

View File

@@ -19,6 +19,7 @@ from app.features.conversation.models import Conversation
from app.features.conversation.session_history import (
conversation_messages_to_redis_history,
)
from app.features.conversation.tts_delivery import apply_presigned_tts_urls_to_messages
from app.features.memory import repo as memory_repo
from app.features.quota.service import QuotaService
from app.ports.storage import ObjectStorage
@@ -248,11 +249,13 @@ class ConversationService:
conv = await self.get_or_404(conversation_id, user_id)
try:
history = await self.ensure_redis_history_from_db(conversation_id)
return _build_messages_from_history(
messages = _build_messages_from_history(
conversation_id=conversation_id,
history=history,
fallback_timestamp=conv.started_at,
)
apply_presigned_tts_urls_to_messages(messages, self._object_storage)
return messages
except Exception:
return []

View File

@@ -0,0 +1,28 @@
"""
对话 TTS 音频 URL 下发到客户端。
与回忆录章节图片一致:私有桶下不能把「直链」当公开可读 URL 使用,应对 COS object key
生成预签名下载地址后再交给 App参见 `normalize_image_assets_for_api` 中的 `get_download_url`)。
持久化DB / Redis仍保存 upload 返回的 canonical URL仅在 API 响应与 WS 实时下发时做 presign。
"""
from __future__ import annotations
from app.core.cos_url_keys import presign_tts_urls_for_playback
from app.ports.storage import ObjectStorage
def apply_presigned_tts_urls_to_messages(
messages: list[dict],
storage: ObjectStorage | None,
) -> None:
"""就地改写助手消息的 `ttsAudioUrls` 为预签名 URL无 storage 时不变。"""
if not storage:
return
for m in messages:
tts = m.get("ttsAudioUrls")
if not isinstance(tts, list) or not tts:
continue
str_urls = [x for x in tts if isinstance(x, str)]
m["ttsAudioUrls"] = presign_tts_urls_for_playback(str_urls, storage)

View File

@@ -16,6 +16,7 @@ class MessageType(str, Enum):
TRANSCRIPT = "transcript"
AGENT_RESPONSE = "agent_response"
TTS_AUDIO = "tts_audio"
TTS_CANCEL = "tts_cancel"
END_CONVERSATION = "end_conversation"
MEMOIR_UPDATE = "memoir_update"
ERROR = "error"

View File

@@ -6,7 +6,7 @@ import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from app.core.logging import get_logger
@@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat import ChatOrchestrator
from app.core.agent_logging import agent_summary_enabled
from app.core.config import settings
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
from app.features.conversation.history_store import ConversationHistoryStore
@@ -35,6 +36,17 @@ from app.features.user.models import User
logger = get_logger(__name__)
# 客户端发送 tts_cancel 时递增process_user_message 内 TTS 循环与合成前后对照,用于短路剩余片段
_tts_cancel_epoch: dict[str, int] = {}
def bump_tts_cancel_epoch(conversation_id: str) -> None:
_tts_cancel_epoch[conversation_id] = _tts_cancel_epoch.get(conversation_id, 0) + 1
def _tts_epoch_value(conversation_id: str) -> int:
return _tts_cancel_epoch.get(conversation_id, 0)
def _tts_object_ext(codec: str) -> str:
c = (codec or "mp3").lower().lstrip(".")
@@ -58,10 +70,14 @@ async def _send_tts_audio(
*,
chunk_index: int,
chunk_total: int,
assistant_message_id: str | None,
tts_epoch_start: int,
) -> str | None:
"""Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None."""
if not settings.enable_tts:
return None
if _tts_epoch_value(conversation_id) != tts_epoch_start:
return None
try:
tts = get_tts_provider()
audio_bytes = await tts.synthesize(text)
@@ -70,23 +86,30 @@ async def _send_tts_audio(
"TTS skipped: synthesize returned empty. Check TTS config in .env"
)
return None
if _tts_epoch_value(conversation_id) != tts_epoch_start:
return None
ext = _tts_object_ext(settings.tts_codec)
content_type = _tts_codec_to_content_type(settings.tts_codec)
storage = get_object_storage()
key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
public_url = storage.upload(key, audio_bytes, content_type)
# 与 `tts_delivery.apply_presigned_tts_urls_to_messages` / 回忆录图片 presign 一致:下发可播 URL
playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
payload_data: Dict[str, Any] = {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
"audio_url": playback_url,
"index": chunk_index,
"total": chunk_total,
}
if assistant_message_id:
payload_data["assistant_message_id"] = assistant_message_id
await manager.send_message(
conversation_id,
{
"type": MessageType.TTS_AUDIO,
"conversation_id": conversation_id,
"data": {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
"audio_url": public_url,
"index": chunk_index,
"total": chunk_total,
},
"data": payload_data,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
@@ -565,7 +588,7 @@ async def process_user_message(
segment.agent_response = "\n\n".join(responses)
_mark_conversation_active(conversation)
await store.record_human_ai_turn(
ai_msg_id = await store.record_human_ai_turn(
conversation_id=conversation_id,
user_message=user_message,
responses=responses,
@@ -576,7 +599,10 @@ async def process_user_message(
tts_audio_urls=None,
segment_id=segment.id,
)
if not ai_msg_id:
return
tts_epoch_start = _tts_epoch_value(conversation_id)
n = len(responses)
for i, response_text in enumerate(responses):
await manager.send_message(
@@ -594,14 +620,20 @@ async def process_user_message(
)
url = None
if not skip_tts:
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
url = await _send_tts_audio(
conversation_id,
response_text,
chunk_index=i,
chunk_total=n,
assistant_message_id=ai_msg_id,
tts_epoch_start=tts_epoch_start,
)
if url:
tts_urls.append(url)
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
if i < n - 1:
await asyncio.sleep(0.5)

View File

@@ -26,6 +26,7 @@ from app.features.conversation.ws.pipeline import (
_mark_conversation_active,
_voice_session_id_from_client_segment_id,
background_runner,
bump_tts_cancel_epoch,
chat_orchestrator,
cleanup_segment_states,
get_or_create_segment_state,
@@ -604,6 +605,9 @@ async def websocket_endpoint(
},
)
elif msg_type == MessageType.TTS_CANCEL:
bump_tts_cancel_epoch(conversation_id)
elif msg_type == MessageType.END_CONVERSATION:
conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc)