fix(conversation): 离屏不丢回复、列表预热 WS 与非阻塞进入聊天

- 后端:文本/转写后 AI 生成改为独立任务,避免断连取消整轮;按需 TTS 等与 WS 改动
- 前端:RealtimeSession 重绑 UI 时恢复流式 buffer;列表 onPressIn/挂载预热、已有会话立即 push
- 同步会话相关类型、i18n、测试与 env/资源等累计改动

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Kevin
2026-05-08 17:28:31 +08:00
parent 5dac3efd52
commit d0c26242db
44 changed files with 1209 additions and 212 deletions

View File

@@ -18,9 +18,13 @@ from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat import ChatOrchestrator
from app.agents.chat.reply_limits import segments_from_llm_response
from app.core.agent_logging import agent_summary_enabled
from app.core.config import settings
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
from app.core.cos_url_keys import (
TTS_PRESIGNED_EXPIRES_SEC,
extract_cos_object_key_if_owned,
)
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
from app.features.conversation.chat_turn import (
@@ -33,7 +37,7 @@ from app.features.conversation.history_store import (
ConversationHistoryStore,
)
from app.features.conversation.lineage_schemas import DialogueLineage
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.models import Conversation, ConversationMessage, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.profile_collector import (
@@ -84,6 +88,7 @@ async def _send_tts_audio(
chunk_total: int,
assistant_message_id: str | None,
tts_epoch_start: int,
manual: bool = False,
) -> str | None:
"""Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None."""
if not settings.enable_tts:
@@ -116,6 +121,8 @@ async def _send_tts_audio(
}
if assistant_message_id:
payload_data["assistant_message_id"] = assistant_message_id
if manual:
payload_data["manual"] = True
await manager.send_message(
conversation_id,
{
@@ -138,6 +145,109 @@ async def _send_tts_audio(
return None
async def handle_tts_request_on_demand(
    *,
    conversation_id: str,
    user_id: str,
    assistant_message_id: str,
    segment_index: int,
    segment_text: str | None,
    db: AsyncSession,
) -> tuple[bool, str]:
    """Serve a user-initiated (speaker button) TTS request for one reply segment.

    When the requested segment already has a stored audio URL, that audio is
    sent again (re-signed when an object key can be extracted from the URL)
    and nothing is synthesized. Otherwise the segment is synthesized through
    ``_send_tts_audio``, the returned URL is written to the message's
    ``tts_audio_urls`` and committed, and the history store is asked to
    refresh Redis.

    Args:
        conversation_id: Conversation the assistant message must belong to.
        user_id: Requesting user; must match the conversation's ``user_id``.
        assistant_message_id: Primary key of the ``ConversationMessage``.
        segment_index: Zero-based index into the segments derived from the
            message content.
        segment_text: Text the client believes it is asking for. It is only
            compared for a debug log; the server always reads its own text.
        db: Session used for lookups and for persisting the new URL.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)`` where
        ``reason`` is a user-facing message.
    """
    if not settings.enable_tts:
        return False, "未开启语音合成"

    conversation = await db.get(Conversation, conversation_id)
    conversation_accessible = (
        conversation
        and conversation.user_id == user_id
        and conversation.deleted_at is None
    )
    if not conversation_accessible:
        return False, "对话不存在或无权访问"

    message = await db.get(ConversationMessage, assistant_message_id)
    is_ai_message_of_conversation = (
        message
        and message.conversation_id == conversation_id
        and message.role == "ai"
    )
    if not is_ai_message_of_conversation:
        return False, "消息不存在"

    # Keep in step with the client's splitMessageParts and with
    # segments_from_llm_response (including paragraph splitting when the
    # reply carries no [SPLIT] marker).
    segments = segments_from_llm_response(message.content or "", max_segments=3)
    segment_count = len(segments)
    if not 0 <= segment_index < segment_count:
        return False, "分段序号无效"

    spoken_text = (segments[segment_index] or "").strip()
    if not spoken_text:
        return False, "该段无朗读文本"

    client_text = segment_text.strip() if segment_text else ""
    if client_text and client_text != spoken_text:
        logger.debug(
            "按需 TTS: 客户端传入 segment_text 与规范化后 canon 不完全一致,已按 segment_index 朗读 canon "
            "(client_len={} canon_len={})",
            len(client_text),
            len(spoken_text),
        )

    # One slot per segment: anything that is not a non-blank string becomes
    # "", and the list is padded so every segment index is addressable.
    stored_urls: list[str] = [
        entry if isinstance(entry, str) and entry.strip() else ""
        for entry in (message.tts_audio_urls or [])
    ]
    stored_urls.extend([""] * (segment_count - len(stored_urls)))

    cached_url = stored_urls[segment_index].strip()
    if cached_url:
        storage = get_object_storage()
        object_key = extract_cos_object_key_if_owned(cached_url)
        # Fall back to the stored URL when no key is extracted or signing fails.
        playback_url = cached_url
        try:
            if object_key:
                playback_url = storage.get_url(
                    object_key, expires=TTS_PRESIGNED_EXPIRES_SEC
                )
        except Exception as exc:
            logger.warning("按需 TTS 预签名失败: {}", exc)

        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": {
                    "audio_url": playback_url,
                    "format": settings.tts_codec,
                    "index": segment_index,
                    "total": segment_count,
                    "assistant_message_id": assistant_message_id,
                    "manual": True,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
        return True, ""

    synthesized_url = await _send_tts_audio(
        conversation_id,
        spoken_text,
        chunk_index=segment_index,
        chunk_total=segment_count,
        assistant_message_id=assistant_message_id,
        tts_epoch_start=_tts_epoch_value(conversation_id),
        manual=True,
    )
    if not synthesized_url:
        return False, "语音合成失败"

    stored_urls[segment_index] = synthesized_url
    message.tts_audio_urls = stored_urls
    await db.commit()
    await ConversationHistoryStore(db)._sync_redis_best_effort(conversation_id)
    return True, ""
# ── Agent instances (moved out of ConnectionManager) ─────────────────────
chat_orchestrator = ChatOrchestrator()
chat_turn_service = ChatTurnService(chat_orchestrator)
@@ -153,6 +263,8 @@ class SegmentStreamState:
"""会话内分段处理状态(用于并行 ASR + 有序聚合)"""
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
#: 本条语音会话最近一次分段上行携带的本轮朗读开关(客户端每段一致即可)
tts_this_turn: bool = False
pending_indices: Set[int] = field(default_factory=set)
processed_indices: Set[int] = field(default_factory=set)
buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
@@ -163,6 +275,43 @@ class SegmentStreamState:
# Per-voice-session segment aggregation state.
# NOTE(review): the key is presumably (conversation_id, voice_session_id) — confirm
# against get_or_create_segment_state.
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}
# conversation_id -> background reply tasks registered for that conversation.
_user_response_tasks: Dict[str, Set[asyncio.Task]] = {}
# conversation_id -> lock created on first request for that conversation.
_user_response_locks: Dict[str, asyncio.Lock] = {}
def _get_user_response_lock(conversation_id: str) -> asyncio.Lock:
    """Return the reply lock for ``conversation_id``, creating it on first use."""
    try:
        return _user_response_locks[conversation_id]
    except KeyError:
        created = asyncio.Lock()
        _user_response_locks[conversation_id] = created
        return created
def register_user_response_task(conversation_id: str, task: asyncio.Task) -> None:
    """Track a background reply task and clean up after it finishes.

    The task is added to the per-conversation set in ``_user_response_tasks``.
    When it completes, a done-callback removes it; once the set is empty the
    conversation's entries in ``_user_response_tasks`` and
    ``_user_response_locks`` are dropped. Cancellation is logged as a warning
    and an unhandled exception as an error, with its traceback.

    Args:
        conversation_id: Conversation the task is generating a reply for.
        task: The already-created asyncio task to track.
    """
    tasks = _user_response_tasks.setdefault(conversation_id, set())
    tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        tasks.discard(done_task)
        if not tasks:
            _user_response_tasks.pop(conversation_id, None)
            _user_response_locks.pop(conversation_id, None)

        # Must be checked first: Task.exception() raises CancelledError on a
        # cancelled task.
        if done_task.cancelled():
            logger.warning(
                "用户回复后台任务被取消 conversation_id={}",
                conversation_id,
            )
            return

        exc = done_task.exception()
        if exc:
            # No exception is being handled inside a done-callback, so the
            # failure is attached explicitly; otherwise the traceback of the
            # background task would not be logged.
            logger.opt(exception=exc).error(
                "用户回复后台任务异常 conversation_id={}: {}",
                conversation_id,
                exc,
            )

    task.add_done_callback(_cleanup)
def get_or_create_segment_state(
@@ -432,9 +581,13 @@ async def process_audio_segment(
audio_base64: str,
audio_duration: int,
is_last: bool,
*,
tts_this_turn: bool = False,
) -> None:
"""分段语音的异步处理:并行 ASR + 幂等落库 + 有序聚合触发 Agent。"""
state = get_or_create_segment_state(conversation_id, voice_session_id)
async with state.lock:
state.tts_this_turn = bool(tts_this_turn)
logger.info(
"process_audio_segment 开始: conversation_id={} voice_session_id={} "
"segment_index={} is_last={} duration_s={} audio_b64_len={}",
@@ -588,6 +741,7 @@ async def process_audio_segment(
)
ready_segments: List[Tuple[int, str, Segment]] = []
tts_flag_this_voice_session = False
async with state.lock:
state.processed_indices.add(segment_index)
state.buffered_transcripts[segment_index] = (
@@ -602,6 +756,8 @@ async def process_audio_segment(
state.consumed_index = next_index
next_index += 1
tts_flag_this_voice_session = bool(state.tts_this_turn)
for _, ordered_text, ordered_segment in ready_segments:
await process_user_message(
conversation_id=conversation_id,
@@ -612,6 +768,7 @@ async def process_audio_segment(
user=user,
user_message_timestamp=ordered_segment.created_at
or user_message_timestamp,
tts_this_turn=tts_flag_this_voice_session,
)
except Exception as e:
@@ -638,6 +795,48 @@ async def process_audio_segment(
# ── 用户消息处理 ────────────────────────────────────────────────
async def process_persisted_user_segment_response(
    *,
    conversation_id: str,
    user_id: str,
    segment_id: str,
    tts_this_turn: bool = False,
) -> None:
    """Generate the assistant reply for an already-persisted user segment.

    Intended to run as a background task so the reply is still produced and
    stored after the WebSocket page has gone away. The work runs under the
    conversation's reply lock and in its own database session. If the
    conversation, user or segment cannot be loaded, or they do not belong
    together, a warning is logged and nothing else happens.

    Args:
        conversation_id: Conversation that owns the segment.
        user_id: Owner of the conversation.
        segment_id: Primary key of the persisted user segment.
        tts_this_turn: Forwarded unchanged to ``process_user_message``.
    """
    async with _get_user_response_lock(conversation_id), AsyncSessionLocal() as db:
        conversation = await db.get(Conversation, conversation_id)
        user = await db.get(User, user_id)
        segment = await db.get(Segment, segment_id)

        rows_consistent = (
            conversation
            and conversation.deleted_at is None
            and conversation.user_id == user_id
            and user
            and segment
            and segment.conversation_id == conversation_id
        )
        if not rows_consistent:
            logger.warning(
                "跳过用户回复后台任务: conversation_id={} segment_id={} user_id={}",
                conversation_id,
                segment_id,
                user_id,
            )
            return

        # Prefer the segment's own creation time; fall back to the
        # conversation's last-message time when it is missing.
        message_timestamp = segment.created_at or conversation.last_message_at
        await process_user_message(
            conversation_id=conversation_id,
            user_message=segment.user_input_text or "",
            conversation=conversation,
            segment=segment,
            db=db,
            user=user,
            user_message_timestamp=message_timestamp,
            tts_this_turn=tts_this_turn,
        )
async def process_user_message(
conversation_id: str,
user_message: str,
@@ -648,6 +847,7 @@ async def process_user_message(
user_message_timestamp: Optional[datetime] = None,
*,
force_skip_tts: bool = False,
tts_this_turn: Optional[bool] = None,
) -> None:
"""处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent。"""
store = ConversationHistoryStore(db)
@@ -682,20 +882,23 @@ async def process_user_message(
get_filled_profile_fields_fn=get_filled_profile_fields,
),
)
responses = turn.messages
skip_tts = bool(turn.skip_tts)
want_voice = bool(tts_this_turn) if tts_this_turn is not None else False
want_tts = want_voice and settings.enable_tts and not skip_tts
if agent_summary_enabled():
logger.info(
"pipeline.process_user_message duration_ms={:.2f} "
"conversation_id={} segment_id={} user_msg_len={} "
"response_segments={} skip_tts={}",
"response_segments={} skip_tts={} want_tts={}",
(time.perf_counter() - t_pipeline) * 1000,
conversation_id,
segment.id,
len(user_message or ""),
len(turn.messages),
turn.skip_tts,
want_tts,
)
responses = turn.messages
skip_tts = bool(turn.skip_tts)
segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses)
_mark_conversation_active(conversation)
@@ -750,6 +953,21 @@ async def process_user_message(
tts_epoch_start = _tts_epoch_value(conversation_id)
n = len(responses)
for i, response_text in enumerate(responses):
url_for_segment: Optional[str] = None
if want_tts:
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
url_for_segment = await _send_tts_audio(
conversation_id,
response_text,
chunk_index=i,
chunk_total=n,
assistant_message_id=ai_msg_id,
tts_epoch_start=tts_epoch_start,
)
if url_for_segment:
tts_urls.append(url_for_segment)
await manager.send_message(
conversation_id,
{
@@ -764,20 +982,7 @@ async def process_user_message(
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
url = None
if not skip_tts:
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
url = await _send_tts_audio(
conversation_id,
response_text,
chunk_index=i,
chunk_total=n,
assistant_message_id=ai_msg_id,
tts_epoch_start=tts_epoch_start,
)
if url:
tts_urls.append(url)
if _tts_epoch_value(conversation_id) != tts_epoch_start:
break
if i < n - 1: