feat(conversation): TTS 投递与 WebSocket 管线;客户端播放门禁与会话页联动;COS 键与迁移脚本调整

This commit is contained in:
Kevin
2026-03-26 15:51:24 +08:00
parent c23931ec91
commit d990399112
22 changed files with 630 additions and 74 deletions

View File

@@ -6,6 +6,13 @@ from typing import Any
from urllib.parse import urlparse
from app.core.config import settings
from app.core.logging import get_logger
from app.ports.storage import ObjectStorage
logger = get_logger(__name__)
# 客户端再读 TTS / 拉取音频:预签名有效期(秒),与移动端会话长度匹配
TTS_PRESIGNED_EXPIRES_SEC = 86_400
def extract_cos_object_key_if_owned(url: str | None) -> str | None:
@@ -75,3 +82,40 @@ def collect_cos_keys_from_tts_url_list(urls: list[str] | None) -> set[str]:
if k:
keys.add(k)
return keys
def presign_tts_urls_for_playback(
urls: list[str],
storage: ObjectStorage | None,
*,
expires: int = TTS_PRESIGNED_EXPIRES_SEC,
) -> list[str]:
"""
将本环境 COS 直链替换为预签名下载 URL私有桶下匿名 GET 会 AccessDenied
目的与回忆录 `normalize_image_assets_for_api` 中对 `get_download_url` 的用法一致。
非本环境 URL 或无法解析 key 时原样返回。
"""
if not storage or not urls:
return list(urls)
out: list[str] = []
for u in urls:
if not isinstance(u, str):
continue
s = u.strip()
if not s:
continue
key = extract_cos_object_key_if_owned(s)
if key:
try:
out.append(storage.get_url(key, expires=expires))
except Exception as exc:
logger.warning(
"presign tts url failed, keeping original url: key={} err={}",
key,
exc,
)
out.append(s)
else:
out.append(s)
return out

View File

@@ -118,7 +118,7 @@ class RedisService:
async def append_tts_audio_url_to_last_ai_message(
self, conversation_id: str, url: str
) -> bool:
"""向最近一条 AI 消息的 ttsAudioUrls 追加 COS 公开 URL。"""
"""向最近一条 AI 消息的 ttsAudioUrls 追加 upload 返回的 canonical URL非预签名。客户端通过 GET /messages 等出口收到预签名 URL。"""
if not url:
return False
try:

View File

@@ -85,9 +85,9 @@ class ConversationHistoryStore:
audio_duration_seconds: int | None,
tts_audio_urls: list[str] | None,
segment_id: str | None,
) -> None:
) -> str | None:
if not responses:
return
return None
human_ts = user_message_timestamp or _utc_now()
if human_ts.tzinfo is None:
human_ts = human_ts.replace(tzinfo=timezone.utc)
@@ -122,6 +122,7 @@ class ConversationHistoryStore:
await self._touch_conversation(conversation_id, occurred_at=ai_ts)
await self._db.commit()
await self._sync_redis_best_effort(conversation_id)
return ai.id
async def attach_ai_tts_audio_urls(
self,

View File

@@ -19,6 +19,7 @@ from app.features.conversation.models import Conversation
from app.features.conversation.session_history import (
conversation_messages_to_redis_history,
)
from app.features.conversation.tts_delivery import apply_presigned_tts_urls_to_messages
from app.features.memory import repo as memory_repo
from app.features.quota.service import QuotaService
from app.ports.storage import ObjectStorage
@@ -248,11 +249,13 @@ class ConversationService:
conv = await self.get_or_404(conversation_id, user_id)
try:
history = await self.ensure_redis_history_from_db(conversation_id)
return _build_messages_from_history(
messages = _build_messages_from_history(
conversation_id=conversation_id,
history=history,
fallback_timestamp=conv.started_at,
)
apply_presigned_tts_urls_to_messages(messages, self._object_storage)
return messages
except Exception:
return []

View File

@@ -0,0 +1,28 @@
"""
对话 TTS 音频 URL 下发到客户端。
与回忆录章节图片一致:私有桶下不能把「直链」当公开可读 URL 使用,应对 COS object key
生成预签名下载地址后再交给 App参见 `normalize_image_assets_for_api` 中的 `get_download_url`)。
持久化DB / Redis仍保存 upload 返回的 canonical URL仅在 API 响应与 WS 实时下发时做 presign。
"""
from __future__ import annotations
from app.core.cos_url_keys import presign_tts_urls_for_playback
from app.ports.storage import ObjectStorage
def apply_presigned_tts_urls_to_messages(
messages: list[dict],
storage: ObjectStorage | None,
) -> None:
"""就地改写助手消息的 `ttsAudioUrls` 为预签名 URL无 storage 时不变。"""
if not storage:
return
for m in messages:
tts = m.get("ttsAudioUrls")
if not isinstance(tts, list) or not tts:
continue
str_urls = [x for x in tts if isinstance(x, str)]
m["ttsAudioUrls"] = presign_tts_urls_for_playback(str_urls, storage)

View File

@@ -16,6 +16,7 @@ class MessageType(str, Enum):
TRANSCRIPT = "transcript"
AGENT_RESPONSE = "agent_response"
TTS_AUDIO = "tts_audio"
TTS_CANCEL = "tts_cancel"
END_CONVERSATION = "end_conversation"
MEMOIR_UPDATE = "memoir_update"
ERROR = "error"

View File

@@ -6,7 +6,7 @@ import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from app.core.logging import get_logger
@@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat import ChatOrchestrator
from app.core.agent_logging import agent_summary_enabled
from app.core.config import settings
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
from app.features.conversation.history_store import ConversationHistoryStore
@@ -35,6 +36,17 @@ from app.features.user.models import User
logger = get_logger(__name__)
# 客户端发送 tts_cancel 时递增process_user_message 内 TTS 循环与合成前后对照,用于短路剩余片段
_tts_cancel_epoch: dict[str, int] = {}
def bump_tts_cancel_epoch(conversation_id: str) -> None:
_tts_cancel_epoch[conversation_id] = _tts_cancel_epoch.get(conversation_id, 0) + 1
def _tts_epoch_value(conversation_id: str) -> int:
return _tts_cancel_epoch.get(conversation_id, 0)
def _tts_object_ext(codec: str) -> str:
c = (codec or "mp3").lower().lstrip(".")
@@ -58,10 +70,14 @@ async def _send_tts_audio(
*,
chunk_index: int,
chunk_total: int,
assistant_message_id: str | None,
tts_epoch_start: int,
) -> str | None:
"""Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None."""
if not settings.enable_tts:
return None
if _tts_epoch_value(conversation_id) != tts_epoch_start:
return None
try:
tts = get_tts_provider()
audio_bytes = await tts.synthesize(text)
@@ -70,23 +86,30 @@ async def _send_tts_audio(
"TTS skipped: synthesize returned empty. Check TTS config in .env"
)
return None
if _tts_epoch_value(conversation_id) != tts_epoch_start:
return None
ext = _tts_object_ext(settings.tts_codec)
content_type = _tts_codec_to_content_type(settings.tts_codec)
storage = get_object_storage()
key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
public_url = storage.upload(key, audio_bytes, content_type)
# 与 `tts_delivery.apply_presigned_tts_urls_to_messages` / 回忆录图片 presign 一致:下发可播 URL
playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
payload_data: Dict[str, Any] = {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
"audio_url": playback_url,
"index": chunk_index,
"total": chunk_total,
}
if assistant_message_id:
payload_data["assistant_message_id"] = assistant_message_id
await manager.send_message(
conversation_id,
{
"type": MessageType.TTS_AUDIO,
"conversation_id": conversation_id,
"data": {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
"audio_url": public_url,
"index": chunk_index,
"total": chunk_total,
},
"data": payload_data,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
@@ -565,7 +588,7 @@ async def process_user_message(
segment.agent_response = "\n\n".join(responses)
_mark_conversation_active(conversation)
await store.record_human_ai_turn(
ai_msg_id = await store.record_human_ai_turn(
conversation_id=conversation_id,
user_message=user_message,
responses=responses,
@@ -576,7 +599,10 @@ async def process_user_message(
tts_audio_urls=None,
segment_id=segment.id,
)
if not ai_msg_id:
return
tts_epoch_start = _tts_epoch_value(conversation_id)
n = len(responses)
for i, response_text in enumerate(responses):
await manager.send_message(
@@ -594,14 +620,20 @@ async def process_user_message(
)
url = None
if not skip_tts:
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
url = await _send_tts_audio(
conversation_id,
response_text,
chunk_index=i,
chunk_total=n,
assistant_message_id=ai_msg_id,
tts_epoch_start=tts_epoch_start,
)
if url:
tts_urls.append(url)
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
if i < n - 1:
await asyncio.sleep(0.5)

View File

@@ -26,6 +26,7 @@ from app.features.conversation.ws.pipeline import (
_mark_conversation_active,
_voice_session_id_from_client_segment_id,
background_runner,
bump_tts_cancel_epoch,
chat_orchestrator,
cleanup_segment_states,
get_or_create_segment_state,
@@ -604,6 +605,9 @@ async def websocket_endpoint(
},
)
elif msg_type == MessageType.TTS_CANCEL:
bump_tts_cancel_epoch(conversation_id)
elif msg_type == MessageType.END_CONVERSATION:
conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc)

View File

@@ -14,12 +14,17 @@ refresh_tokens / segments / sms_verification_codes / users见仓库内历史
2) 目标库已执行 ``alembic upgrade head``(含 pgvector 与当前 ORM 表)。
3) 运行::
3) 运行(仓库内可用 ``uv run python scripts/...``::
cd api && uv run python scripts/migrate_legacy_to_current.py \\
python3 migrate_legacy_to_current.py \\
--legacy-url postgresql://postgres:postgres@localhost:5432/life_echo_legacy \\
--target-url postgresql://postgres:postgres@localhost:5432/life_echo
**仅服务器 + psycopg**:本脚本不依赖项目内其它包,可复制单文件到机器上执行::
pip install 'psycopg[binary]' # 或 python3 -m pip install --user 'psycopg[binary]'
python3 migrate_legacy_to_current.py --legacy-url ... --target-url ...
说明:
- 不会创建 stories / memory_* / conversation_messages 等旧库中不存在的表数据;
- chapterscontent → canonical_markdown按 user_id 关联该用户唯一 book 填 book_id若无书则为 NULL
@@ -31,32 +36,44 @@ refresh_tokens / segments / sms_verification_codes / users见仓库内历史
若目标库已有用户且手机号与某条 legacy 用户冲突(同号不同 id会自动跳过该 legacy 用户及其 books/chapters/
conversations 等关联行,避免违反 ``users.phone`` 唯一约束。新生产库一般为空库,不会触发。
**宿主机上跑脚本(数据库在 Docker Compose 里)**`.env` 里常见主机名 `postgres`,在容器外无法解析。
可直接把 URL 写成 `...@127.0.0.1:5432/...`,或使用 `--db-host 127.0.0.1` 自动替换两个 URL 中的主机名(端口不变)。
"""
from __future__ import annotations
import argparse
import json
import sys
import logging
import uuid
from pathlib import Path
from typing import Any, Literal
_ROOT = Path(__file__).resolve().parents[1]
if str(_ROOT) not in sys.path:
sys.path.insert(0, str(_ROOT))
from urllib.parse import urlparse, urlunparse
from psycopg import Connection, connect
from psycopg.rows import dict_row
from psycopg.types.json import Json
from app.core.logging import get_logger
logger = get_logger(__name__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
logger = logging.getLogger(__name__)
OnConflict = Literal["upsert", "skip"]
def _replace_url_host(url: str, new_host: str) -> str:
"""将 postgresql URL 中的主机名替换为 new_host保留用户、密码、端口、库名"""
u = urlparse(url)
if not u.netloc or "@" not in u.netloc:
return url
auth, hostport = u.netloc.rsplit("@", 1)
if ":" in hostport:
_old_host, port = hostport.split(":", 1)
new_netloc = f"{auth}@{new_host}:{port}"
else:
new_netloc = f"{auth}@{new_host}"
return urlunparse((u.scheme, new_netloc, u.path, u.params, u.query, u.fragment))
def _open(url: str) -> Connection:
return connect(url, autocommit=False)
@@ -76,7 +93,7 @@ def _legacy_user_ids_skipped_for_phone(
if owner is not None and owner != r["id"]:
skipped.add(r["id"])
logger.warning(
"skip legacy user {} phone={} (target user id={})",
"skip legacy user %s phone=%s (target user id=%s)",
r["id"],
r["phone"],
owner,
@@ -678,11 +695,27 @@ def main() -> None:
action="store_true",
help="Connect and print row counts only; no writes.",
)
p.add_argument(
"--db-host",
metavar="HOST",
default=None,
help=(
"Replace hostname in both URLs (e.g. 127.0.0.1 when running on the host "
"while Postgres is published from Docker; keeps user/password/port/dbname)."
),
)
args = p.parse_args()
on_conflict: OnConflict = args.on_conflict # type: ignore[assignment]
legacy = _open(args.legacy_url)
target = _open(args.target_url)
legacy_url = args.legacy_url
target_url = args.target_url
if args.db_host:
legacy_url = _replace_url_host(legacy_url, args.db_host)
target_url = _replace_url_host(target_url, args.db_host)
logger.info("using db host %s (from --db-host)", args.db_host)
legacy = _open(legacy_url)
target = _open(target_url)
try:
if args.dry_run:
with legacy.cursor(row_factory=dict_row) as cur:
@@ -699,14 +732,14 @@ def main() -> None:
):
cur.execute(f"SELECT COUNT(*) AS c FROM {t}")
c = cur.fetchone()["c"]
logger.info("legacy {} rows={}", t, c)
logger.info("legacy %s rows=%s", t, c)
logger.info("dry-run done")
return
skip_users = _legacy_user_ids_skipped_for_phone(legacy, target)
if skip_users:
logger.info(
"skip {} legacy users due to phone already owned in target",
"skip %d legacy users due to phone already owned in target",
len(skip_users),
)
@@ -724,9 +757,9 @@ def main() -> None:
target.commit()
logger.info(
"migration committed: users={} books={} memoir_states={} "
"conversations={} segments={} orders={} refresh_tokens={} "
"sms={} chapters={} memoir_images={}",
"migration committed: users=%s books=%s memoir_states=%s "
"conversations=%s segments=%s orders=%s refresh_tokens=%s "
"sms=%s chapters=%s memoir_images=%s",
n_users,
n_books,
n_memoir,