feat(conversation): TTS 投递与 WebSocket 管线;客户端播放门禁与会话页联动;COS 键与迁移脚本调整
This commit is contained in:
@@ -6,6 +6,13 @@ from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.ports.storage import ObjectStorage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 客户端再读 TTS / 拉取音频:预签名有效期(秒),与移动端会话长度匹配
|
||||
TTS_PRESIGNED_EXPIRES_SEC = 86_400
|
||||
|
||||
|
||||
def extract_cos_object_key_if_owned(url: str | None) -> str | None:
|
||||
@@ -75,3 +82,40 @@ def collect_cos_keys_from_tts_url_list(urls: list[str] | None) -> set[str]:
|
||||
if k:
|
||||
keys.add(k)
|
||||
return keys
|
||||
|
||||
|
||||
def presign_tts_urls_for_playback(
    urls: list[str],
    storage: ObjectStorage | None,
    *,
    expires: int = TTS_PRESIGNED_EXPIRES_SEC,
) -> list[str]:
    """
    Replace COS direct links owned by this environment with presigned download URLs.

    A private bucket answers anonymous GETs on the direct link with AccessDenied,
    so the client needs a presigned URL to play the audio. The intent mirrors the
    `get_download_url` usage in the memoir's `normalize_image_assets_for_api`.

    Args:
        urls: Stored TTS URLs. A falsy value (empty list or ``None``) yields ``[]``.
        storage: Object storage used to presign. When falsy, a shallow copy of
            ``urls`` is returned untouched (no filtering is applied).
        expires: Lifetime of each presigned URL, in seconds.

    Returns:
        A new list. With a storage backend, non-string and blank entries are
        dropped, surrounding whitespace is stripped, owned URLs are replaced by
        presigned ones, and URLs that are not owned (or whose key cannot be
        resolved, or whose presigning fails) are kept as-is.
    """
    if not urls:
        # Previously ``list(urls)`` ran here and raised TypeError for ``None``,
        # even though the guard itself already treated ``None`` as "nothing to do".
        return []
    if not storage:
        return list(urls)

    out: list[str] = []
    for u in urls:
        if not isinstance(u, str):
            continue
        s = u.strip()
        if not s:
            continue
        key = extract_cos_object_key_if_owned(s)
        if not key:
            # Foreign URL or unparseable key: pass it through unchanged.
            out.append(s)
            continue
        try:
            out.append(storage.get_url(key, expires=expires))
        except Exception as exc:
            # Best effort: a presign failure must not drop the audio entry.
            logger.warning(
                "presign tts url failed, keeping original url: key={} err={}",
                key,
                exc,
            )
            out.append(s)
    return out
|
||||
|
||||
@@ -118,7 +118,7 @@ class RedisService:
|
||||
async def append_tts_audio_url_to_last_ai_message(
|
||||
self, conversation_id: str, url: str
|
||||
) -> bool:
|
||||
"""向最近一条 AI 消息的 ttsAudioUrls 追加 COS 公开 URL。"""
|
||||
"""向最近一条 AI 消息的 ttsAudioUrls 追加 upload 返回的 canonical URL(非预签名)。客户端通过 GET /messages 等出口收到预签名 URL。"""
|
||||
if not url:
|
||||
return False
|
||||
try:
|
||||
|
||||
@@ -85,9 +85,9 @@ class ConversationHistoryStore:
|
||||
audio_duration_seconds: int | None,
|
||||
tts_audio_urls: list[str] | None,
|
||||
segment_id: str | None,
|
||||
) -> None:
|
||||
) -> str | None:
|
||||
if not responses:
|
||||
return
|
||||
return None
|
||||
human_ts = user_message_timestamp or _utc_now()
|
||||
if human_ts.tzinfo is None:
|
||||
human_ts = human_ts.replace(tzinfo=timezone.utc)
|
||||
@@ -122,6 +122,7 @@ class ConversationHistoryStore:
|
||||
await self._touch_conversation(conversation_id, occurred_at=ai_ts)
|
||||
await self._db.commit()
|
||||
await self._sync_redis_best_effort(conversation_id)
|
||||
return ai.id
|
||||
|
||||
async def attach_ai_tts_audio_urls(
|
||||
self,
|
||||
|
||||
@@ -19,6 +19,7 @@ from app.features.conversation.models import Conversation
|
||||
from app.features.conversation.session_history import (
|
||||
conversation_messages_to_redis_history,
|
||||
)
|
||||
from app.features.conversation.tts_delivery import apply_presigned_tts_urls_to_messages
|
||||
from app.features.memory import repo as memory_repo
|
||||
from app.features.quota.service import QuotaService
|
||||
from app.ports.storage import ObjectStorage
|
||||
@@ -248,11 +249,13 @@ class ConversationService:
|
||||
conv = await self.get_or_404(conversation_id, user_id)
|
||||
try:
|
||||
history = await self.ensure_redis_history_from_db(conversation_id)
|
||||
return _build_messages_from_history(
|
||||
messages = _build_messages_from_history(
|
||||
conversation_id=conversation_id,
|
||||
history=history,
|
||||
fallback_timestamp=conv.started_at,
|
||||
)
|
||||
apply_presigned_tts_urls_to_messages(messages, self._object_storage)
|
||||
return messages
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
28
api/app/features/conversation/tts_delivery.py
Normal file
28
api/app/features/conversation/tts_delivery.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
对话 TTS 音频 URL 下发到客户端。
|
||||
|
||||
与回忆录章节图片一致:私有桶下不能把「直链」当公开可读 URL 使用,应对 COS object key
|
||||
生成预签名下载地址后再交给 App(参见 `normalize_image_assets_for_api` 中的 `get_download_url`)。
|
||||
|
||||
持久化(DB / Redis)仍保存 upload 返回的 canonical URL,仅在 API 响应与 WS 实时下发时做 presign。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.core.cos_url_keys import presign_tts_urls_for_playback
|
||||
from app.ports.storage import ObjectStorage
|
||||
|
||||
|
||||
def apply_presigned_tts_urls_to_messages(
    messages: list[dict],
    storage: ObjectStorage | None,
) -> None:
    """Rewrite each message's `ttsAudioUrls` in place with presigned URLs.

    Nothing happens when ``storage`` is falsy. Messages whose `ttsAudioUrls`
    value is missing, empty, or not a list are left untouched. Non-string
    entries are removed before the list is handed to the presigner.
    """
    if storage:
        for message in messages:
            raw_urls = message.get("ttsAudioUrls")
            if isinstance(raw_urls, list) and raw_urls:
                only_strings = [item for item in raw_urls if isinstance(item, str)]
                message["ttsAudioUrls"] = presign_tts_urls_for_playback(
                    only_strings, storage
                )
|
||||
@@ -16,6 +16,7 @@ class MessageType(str, Enum):
|
||||
TRANSCRIPT = "transcript"
|
||||
AGENT_RESPONSE = "agent_response"
|
||||
TTS_AUDIO = "tts_audio"
|
||||
TTS_CANCEL = "tts_cancel"
|
||||
END_CONVERSATION = "end_conversation"
|
||||
MEMOIR_UPDATE = "memoir_update"
|
||||
ERROR = "error"
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
@@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.agents.chat import ChatOrchestrator
|
||||
from app.core.agent_logging import agent_summary_enabled
|
||||
from app.core.config import settings
|
||||
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
|
||||
from app.features.conversation.history_store import ConversationHistoryStore
|
||||
@@ -35,6 +36,17 @@ from app.features.user.models import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Per-conversation cancel counter, incremented when the client sends tts_cancel.
# The TTS loop in process_user_message (and the checks around synthesis) compare
# against the value captured at the start to short-circuit the remaining chunks.
_tts_cancel_epoch: dict[str, int] = {}


def bump_tts_cancel_epoch(conversation_id: str) -> None:
    """Advance the cancel epoch of ``conversation_id`` by one (starting from 0)."""
    previous = _tts_cancel_epoch.get(conversation_id, 0)
    _tts_cancel_epoch[conversation_id] = previous + 1


def _tts_epoch_value(conversation_id: str) -> int:
    """Return the current cancel epoch, or 0 if it was never bumped."""
    try:
        return _tts_cancel_epoch[conversation_id]
    except KeyError:
        return 0
|
||||
|
||||
|
||||
def _tts_object_ext(codec: str) -> str:
|
||||
c = (codec or "mp3").lower().lstrip(".")
|
||||
@@ -58,10 +70,14 @@ async def _send_tts_audio(
|
||||
*,
|
||||
chunk_index: int,
|
||||
chunk_total: int,
|
||||
assistant_message_id: str | None,
|
||||
tts_epoch_start: int,
|
||||
) -> str | None:
|
||||
"""Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None."""
|
||||
if not settings.enable_tts:
|
||||
return None
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
return None
|
||||
try:
|
||||
tts = get_tts_provider()
|
||||
audio_bytes = await tts.synthesize(text)
|
||||
@@ -70,23 +86,30 @@ async def _send_tts_audio(
|
||||
"TTS skipped: synthesize returned empty. Check TTS config in .env"
|
||||
)
|
||||
return None
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
return None
|
||||
ext = _tts_object_ext(settings.tts_codec)
|
||||
content_type = _tts_codec_to_content_type(settings.tts_codec)
|
||||
storage = get_object_storage()
|
||||
key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
|
||||
public_url = storage.upload(key, audio_bytes, content_type)
|
||||
# 与 `tts_delivery.apply_presigned_tts_urls_to_messages` / 回忆录图片 presign 一致:下发可播 URL
|
||||
playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
|
||||
payload_data: Dict[str, Any] = {
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
|
||||
"format": settings.tts_codec,
|
||||
"audio_url": playback_url,
|
||||
"index": chunk_index,
|
||||
"total": chunk_total,
|
||||
}
|
||||
if assistant_message_id:
|
||||
payload_data["assistant_message_id"] = assistant_message_id
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.TTS_AUDIO,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
|
||||
"format": settings.tts_codec,
|
||||
"audio_url": public_url,
|
||||
"index": chunk_index,
|
||||
"total": chunk_total,
|
||||
},
|
||||
"data": payload_data,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
@@ -565,7 +588,7 @@ async def process_user_message(
|
||||
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
_mark_conversation_active(conversation)
|
||||
await store.record_human_ai_turn(
|
||||
ai_msg_id = await store.record_human_ai_turn(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
responses=responses,
|
||||
@@ -576,7 +599,10 @@ async def process_user_message(
|
||||
tts_audio_urls=None,
|
||||
segment_id=segment.id,
|
||||
)
|
||||
if not ai_msg_id:
|
||||
return
|
||||
|
||||
tts_epoch_start = _tts_epoch_value(conversation_id)
|
||||
n = len(responses)
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(
|
||||
@@ -594,14 +620,20 @@ async def process_user_message(
|
||||
)
|
||||
url = None
|
||||
if not skip_tts:
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
break
|
||||
url = await _send_tts_audio(
|
||||
conversation_id,
|
||||
response_text,
|
||||
chunk_index=i,
|
||||
chunk_total=n,
|
||||
assistant_message_id=ai_msg_id,
|
||||
tts_epoch_start=tts_epoch_start,
|
||||
)
|
||||
if url:
|
||||
tts_urls.append(url)
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
break
|
||||
if i < n - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.features.conversation.ws.pipeline import (
|
||||
_mark_conversation_active,
|
||||
_voice_session_id_from_client_segment_id,
|
||||
background_runner,
|
||||
bump_tts_cancel_epoch,
|
||||
chat_orchestrator,
|
||||
cleanup_segment_states,
|
||||
get_or_create_segment_state,
|
||||
@@ -604,6 +605,9 @@ async def websocket_endpoint(
|
||||
},
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.TTS_CANCEL:
|
||||
bump_tts_cancel_epoch(conversation_id)
|
||||
|
||||
elif msg_type == MessageType.END_CONVERSATION:
|
||||
conversation.status = "ended"
|
||||
conversation.ended_at = datetime.now(timezone.utc)
|
||||
|
||||
@@ -14,12 +14,17 @@ refresh_tokens / segments / sms_verification_codes / users(见仓库内历史
|
||||
|
||||
2) 目标库已执行 ``alembic upgrade head``(含 pgvector 与当前 ORM 表)。
|
||||
|
||||
3) 运行::
|
||||
3) 运行(仓库内可用 ``uv run python scripts/...``)::
|
||||
|
||||
cd api && uv run python scripts/migrate_legacy_to_current.py \\
|
||||
python3 migrate_legacy_to_current.py \\
|
||||
--legacy-url postgresql://postgres:postgres@localhost:5432/life_echo_legacy \\
|
||||
--target-url postgresql://postgres:postgres@localhost:5432/life_echo
|
||||
|
||||
**仅服务器 + psycopg**:本脚本不依赖项目内其它包,可复制单文件到机器上执行::
|
||||
|
||||
pip install 'psycopg[binary]' # 或 python3 -m pip install --user 'psycopg[binary]'
|
||||
python3 migrate_legacy_to_current.py --legacy-url ... --target-url ...
|
||||
|
||||
说明:
|
||||
- 不会创建 stories / memory_* / conversation_messages 等旧库中不存在的表数据;
|
||||
- chapters:content → canonical_markdown;按 user_id 关联该用户唯一 book 填 book_id(若无书则为 NULL);
|
||||
@@ -31,32 +36,44 @@ refresh_tokens / segments / sms_verification_codes / users(见仓库内历史
|
||||
|
||||
若目标库已有用户且手机号与某条 legacy 用户冲突(同号不同 id),会自动跳过该 legacy 用户及其 books/chapters/
|
||||
conversations 等关联行,避免违反 ``users.phone`` 唯一约束。新生产库一般为空库,不会触发。
|
||||
|
||||
**宿主机上跑脚本(数据库在 Docker Compose 里)**:`.env` 里常见主机名 `postgres`,在容器外无法解析。
|
||||
可直接把 URL 写成 `...@127.0.0.1:5432/...`,或使用 `--db-host 127.0.0.1` 自动替换两个 URL 中的主机名(端口不变)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_ROOT))
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from psycopg import Connection, connect
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg.types.json import Json
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OnConflict = Literal["upsert", "skip"]
|
||||
|
||||
|
||||
def _replace_url_host(url: str, new_host: str) -> str:
    """Return ``url`` with its host swapped for ``new_host``.

    Scheme, credentials, port, path, params, query and fragment are preserved.
    Handles URLs with or without a ``user:password@`` part and bracketed IPv6
    hosts (``[::1]:5432``). An IPv6 ``new_host`` given without brackets is
    wrapped in ``[...]`` so the resulting netloc stays parseable.

    A URL with no network location at all (for example ``postgresql:///db``) is
    returned unchanged.

    NOTE(review): multi-host connection URLs (``host1:5432,host2:5432``) are not
    treated specially; only the part before the first port separator is replaced.
    """
    u = urlparse(url)
    if not u.netloc:
        return url

    # rpartition keeps working when the password itself contains "@", and yields
    # an empty ``auth``/``at`` pair when the URL carries no credentials.
    auth, at, hostport = u.netloc.rpartition("@")

    if hostport.startswith("["):
        # Bracketed IPv6 literal: whatever follows "]" is "" or ":port".
        port_suffix = hostport.partition("]")[2]
    else:
        colon = hostport.find(":")
        port_suffix = hostport[colon:] if colon != -1 else ""

    host = new_host
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"

    new_netloc = f"{auth}{at}{host}{port_suffix}"
    return urlunparse((u.scheme, new_netloc, u.path, u.params, u.query, u.fragment))
|
||||
|
||||
|
||||
def _open(url: str) -> Connection:
    """Open a psycopg connection to ``url`` with autocommit turned off."""
    connection = connect(url, autocommit=False)
    return connection
|
||||
|
||||
@@ -76,7 +93,7 @@ def _legacy_user_ids_skipped_for_phone(
|
||||
if owner is not None and owner != r["id"]:
|
||||
skipped.add(r["id"])
|
||||
logger.warning(
|
||||
"skip legacy user {} phone={} (target user id={})",
|
||||
"skip legacy user %s phone=%s (target user id=%s)",
|
||||
r["id"],
|
||||
r["phone"],
|
||||
owner,
|
||||
@@ -678,11 +695,27 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Connect and print row counts only; no writes.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--db-host",
|
||||
metavar="HOST",
|
||||
default=None,
|
||||
help=(
|
||||
"Replace hostname in both URLs (e.g. 127.0.0.1 when running on the host "
|
||||
"while Postgres is published from Docker; keeps user/password/port/dbname)."
|
||||
),
|
||||
)
|
||||
args = p.parse_args()
|
||||
on_conflict: OnConflict = args.on_conflict # type: ignore[assignment]
|
||||
|
||||
legacy = _open(args.legacy_url)
|
||||
target = _open(args.target_url)
|
||||
legacy_url = args.legacy_url
|
||||
target_url = args.target_url
|
||||
if args.db_host:
|
||||
legacy_url = _replace_url_host(legacy_url, args.db_host)
|
||||
target_url = _replace_url_host(target_url, args.db_host)
|
||||
logger.info("using db host %s (from --db-host)", args.db_host)
|
||||
|
||||
legacy = _open(legacy_url)
|
||||
target = _open(target_url)
|
||||
try:
|
||||
if args.dry_run:
|
||||
with legacy.cursor(row_factory=dict_row) as cur:
|
||||
@@ -699,14 +732,14 @@ def main() -> None:
|
||||
):
|
||||
cur.execute(f"SELECT COUNT(*) AS c FROM {t}")
|
||||
c = cur.fetchone()["c"]
|
||||
logger.info("legacy {} rows={}", t, c)
|
||||
logger.info("legacy %s rows=%s", t, c)
|
||||
logger.info("dry-run done")
|
||||
return
|
||||
|
||||
skip_users = _legacy_user_ids_skipped_for_phone(legacy, target)
|
||||
if skip_users:
|
||||
logger.info(
|
||||
"skip {} legacy users due to phone already owned in target",
|
||||
"skip %d legacy users due to phone already owned in target",
|
||||
len(skip_users),
|
||||
)
|
||||
|
||||
@@ -724,9 +757,9 @@ def main() -> None:
|
||||
target.commit()
|
||||
|
||||
logger.info(
|
||||
"migration committed: users={} books={} memoir_states={} "
|
||||
"conversations={} segments={} orders={} refresh_tokens={} "
|
||||
"sms={} chapters={} memoir_images={}",
|
||||
"migration committed: users=%s books=%s memoir_states=%s "
|
||||
"conversations=%s segments=%s orders=%s refresh_tokens=%s "
|
||||
"sms=%s chapters=%s memoir_images=%s",
|
||||
n_users,
|
||||
n_books,
|
||||
n_memoir,
|
||||
|
||||
Reference in New Issue
Block a user