fix(conversation): 离屏不丢回复、列表预热 WS 与非阻塞进入聊天
- 后端:文本/转写后 AI 生成改为独立任务,避免断连取消整轮;按需 TTS 等与 WS 改动
- 前端:RealtimeSession 重绑 UI 时恢复流式 buffer;列表 onPressIn/挂载预热、已有会话立即 push
- 同步会话相关类型、i18n、测试与 env/资源等累计改动

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -18,9 +18,13 @@ from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat import ChatOrchestrator
|
||||
from app.agents.chat.reply_limits import segments_from_llm_response
|
||||
from app.core.agent_logging import agent_summary_enabled
|
||||
from app.core.config import settings
|
||||
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
|
||||
from app.core.cos_url_keys import (
|
||||
TTS_PRESIGNED_EXPIRES_SEC,
|
||||
extract_cos_object_key_if_owned,
|
||||
)
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
|
||||
from app.features.conversation.chat_turn import (
|
||||
@@ -33,7 +37,7 @@ from app.features.conversation.history_store import (
|
||||
ConversationHistoryStore,
|
||||
)
|
||||
from app.features.conversation.lineage_schemas import DialogueLineage
|
||||
from app.features.conversation.models import Conversation, Segment
|
||||
from app.features.conversation.models import Conversation, ConversationMessage, Segment
|
||||
from app.features.conversation.ws.connection_manager import manager
|
||||
from app.features.conversation.ws.message_types import MessageType
|
||||
from app.features.conversation.ws.profile_collector import (
|
||||
@@ -84,6 +88,7 @@ async def _send_tts_audio(
|
||||
chunk_total: int,
|
||||
assistant_message_id: str | None,
|
||||
tts_epoch_start: int,
|
||||
manual: bool = False,
|
||||
) -> str | None:
|
||||
"""Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None."""
|
||||
if not settings.enable_tts:
|
||||
@@ -116,6 +121,8 @@ async def _send_tts_audio(
|
||||
}
|
||||
if assistant_message_id:
|
||||
payload_data["assistant_message_id"] = assistant_message_id
|
||||
if manual:
|
||||
payload_data["manual"] = True
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
@@ -138,6 +145,109 @@ async def _send_tts_audio(
|
||||
return None
|
||||
|
||||
|
||||
async def handle_tts_request_on_demand(
    *,
    conversation_id: str,
    user_id: str,
    assistant_message_id: str,
    segment_index: int,
    segment_text: str | None,
    db: AsyncSession,
) -> tuple[bool, str]:
    """Serve a user-initiated ("speaker button") TTS request for one reply segment.

    When the requested segment already has a stored TTS URL, that URL is
    re-sent over the WebSocket (presigned if ``extract_cos_object_key_if_owned``
    yields an object key). Otherwise the segment is synthesized through
    ``_send_tts_audio``, the returned URL is stored on the message, the
    session is committed, and the history store's Redis sync is invoked.
    A segment that already has audio is never synthesized again.

    Args:
        conversation_id: Conversation the message must belong to.
        user_id: Requesting user; must equal the conversation's ``user_id``.
        assistant_message_id: Primary key of the AI message to read aloud.
        segment_index: Zero-based index into the message's split segments.
        segment_text: Segment text as the client sees it. It is only compared
            with the server-side text to emit a debug log; the server-side
            text is always the one that gets spoken.
        db: Async SQLAlchemy session used for the lookups and the commit.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)`` with a short
        reason string.
    """
    if not settings.enable_tts:
        return False, "未开启语音合成"

    conversation = await db.get(Conversation, conversation_id)
    conversation_accessible = (
        conversation
        and conversation.user_id == user_id
        and conversation.deleted_at is None
    )
    if not conversation_accessible:
        return False, "对话不存在或无权访问"

    message = await db.get(ConversationMessage, assistant_message_id)
    is_target_ai_message = (
        message
        and message.conversation_id == conversation_id
        and message.role == "ai"
    )
    if not is_target_ai_message:
        return False, "消息不存在"

    # Split the same way as the client's splitMessageParts /
    # segments_from_llm_response (including paragraph-based splitting when
    # the content has no [SPLIT] marker).
    segments = segments_from_llm_response(message.content or "", max_segments=3)
    segment_count = len(segments)
    if not 0 <= segment_index < segment_count:
        return False, "分段序号无效"

    spoken_text = (segments[segment_index] or "").strip()
    if not spoken_text:
        return False, "该段无朗读文本"

    client_text = (segment_text or "").strip()
    if client_text and client_text != spoken_text:
        logger.debug(
            "按需 TTS: 客户端传入 segment_text 与规范化后 canon 不完全一致,已按 segment_index 朗读 canon "
            "(client_len={} canon_len={})",
            len(client_text),
            len(spoken_text),
        )

    # One URL slot per segment: blank or non-string entries become "" so that
    # list positions stay aligned with segment indices, then pad to length.
    stored_urls: List[str] = [
        entry if isinstance(entry, str) and entry.strip() else ""
        for entry in (message.tts_audio_urls or [])
    ]
    stored_urls.extend([""] * (segment_count - len(stored_urls)))

    cached_url = stored_urls[segment_index].strip()

    if cached_url:
        storage = get_object_storage()
        object_key = extract_cos_object_key_if_owned(cached_url)
        # Fall back to the stored URL when there is no owned key or when
        # presigning fails.
        playback_url = cached_url
        if object_key:
            try:
                playback_url = storage.get_url(
                    object_key, expires=TTS_PRESIGNED_EXPIRES_SEC
                )
            except Exception as exc:
                logger.warning("按需 TTS 预签名失败: {}", exc)

        audio_payload = {
            "audio_url": playback_url,
            "format": settings.tts_codec,
            "index": segment_index,
            "total": segment_count,
            "assistant_message_id": assistant_message_id,
            "manual": True,
        }
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": audio_payload,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
        return True, ""

    # No audio yet for this segment: synthesize it now.
    epoch_at_start = _tts_epoch_value(conversation_id)
    synthesized_url = await _send_tts_audio(
        conversation_id,
        spoken_text,
        chunk_index=segment_index,
        chunk_total=segment_count,
        assistant_message_id=assistant_message_id,
        tts_epoch_start=epoch_at_start,
        manual=True,
    )
    if not synthesized_url:
        return False, "语音合成失败"

    stored_urls[segment_index] = synthesized_url
    # Assign a fresh list object rather than mutating the attribute in place.
    message.tts_audio_urls = stored_urls
    await db.commit()

    history_store = ConversationHistoryStore(db)
    await history_store._sync_redis_best_effort(conversation_id)
    return True, ""
|
||||
|
||||
|
||||
# ── Agent instances (moved out of ConnectionManager) ─────────────────────
|
||||
chat_orchestrator = ChatOrchestrator()
|
||||
chat_turn_service = ChatTurnService(chat_orchestrator)
|
||||
@@ -153,6 +263,8 @@ class SegmentStreamState:
|
||||
"""会话内分段处理状态(用于并行 ASR + 有序聚合)"""
|
||||
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
#: 本条语音会话最近一次分段上行携带的本轮朗读开关(客户端每段一致即可)
|
||||
tts_this_turn: bool = False
|
||||
pending_indices: Set[int] = field(default_factory=set)
|
||||
processed_indices: Set[int] = field(default_factory=set)
|
||||
buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
|
||||
@@ -163,6 +275,43 @@ class SegmentStreamState:
|
||||
|
||||
|
||||
# Per-voice-session segment processing state.
# NOTE(review): presumably keyed by (conversation_id, voice_session_id), the
# arguments passed to get_or_create_segment_state — confirm against that helper.
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}
|
||||
# conversation_id -> in-flight background reply tasks; entries are added and
# removed by register_user_response_task.
_user_response_tasks: Dict[str, Set[asyncio.Task]] = {}
|
||||
# conversation_id -> lock handed out by _get_user_response_lock; dropped again
# once the conversation has no tracked reply tasks left.
_user_response_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def _get_user_response_lock(conversation_id: str) -> asyncio.Lock:
    """Return the reply lock for ``conversation_id``, creating it on first use.

    The lock is kept in the module-level ``_user_response_locks`` registry so
    that repeated calls for the same conversation share one ``asyncio.Lock``.
    """
    existing_lock = _user_response_locks.get(conversation_id)
    if existing_lock is not None:
        return existing_lock

    new_lock = asyncio.Lock()
    _user_response_locks[conversation_id] = new_lock
    return new_lock
|
||||
|
||||
|
||||
def register_user_response_task(conversation_id: str, task: asyncio.Task) -> None:
    """Track a background reply task for a conversation until it finishes.

    The task is added to the conversation's set in ``_user_response_tasks``
    and a done-callback is attached. When the task completes, the callback
    removes it from the set; once the set is empty, both the task set and the
    conversation's entry in ``_user_response_locks`` are dropped. Cancellation
    is logged as a warning and a failure is logged as an error together with
    the exception's traceback.

    Args:
        conversation_id: Conversation the task is generating a reply for.
        task: The already-created asyncio task to track.
    """
    tasks = _user_response_tasks.setdefault(conversation_id, set())
    tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        # Unregister first so the registries are tidied up regardless of how
        # the task ended.
        tasks.discard(done_task)
        if not tasks:
            _user_response_tasks.pop(conversation_id, None)
            _user_response_locks.pop(conversation_id, None)

        # Task.exception() raises CancelledError for a cancelled task, so the
        # cancelled case has to be handled before asking for the exception.
        if done_task.cancelled():
            logger.warning(
                "用户回复后台任务被取消 conversation_id={}",
                conversation_id,
            )
            return

        exc = done_task.exception()
        if exc is not None:
            # A done-callback runs outside any ``except`` block, so there is no
            # "current exception" for the logger to pick up; attach the task's
            # exception explicitly so its traceback ends up in the log.
            # NOTE(review): assumes ``logger`` is loguru's logger (the
            # ``{}``-style positional formatting used here implies it).
            logger.opt(exception=exc).error(
                "用户回复后台任务异常 conversation_id={}: {}",
                conversation_id,
                exc,
            )

    task.add_done_callback(_cleanup)
|
||||
|
||||
|
||||
def get_or_create_segment_state(
|
||||
@@ -432,9 +581,13 @@ async def process_audio_segment(
|
||||
audio_base64: str,
|
||||
audio_duration: int,
|
||||
is_last: bool,
|
||||
*,
|
||||
tts_this_turn: bool = False,
|
||||
) -> None:
|
||||
"""分段语音的异步处理:并行 ASR + 幂等落库 + 有序聚合触发 Agent。"""
|
||||
state = get_or_create_segment_state(conversation_id, voice_session_id)
|
||||
async with state.lock:
|
||||
state.tts_this_turn = bool(tts_this_turn)
|
||||
logger.info(
|
||||
"process_audio_segment 开始: conversation_id={} voice_session_id={} "
|
||||
"segment_index={} is_last={} duration_s={} audio_b64_len={}",
|
||||
@@ -588,6 +741,7 @@ async def process_audio_segment(
|
||||
)
|
||||
|
||||
ready_segments: List[Tuple[int, str, Segment]] = []
|
||||
tts_flag_this_voice_session = False
|
||||
async with state.lock:
|
||||
state.processed_indices.add(segment_index)
|
||||
state.buffered_transcripts[segment_index] = (
|
||||
@@ -602,6 +756,8 @@ async def process_audio_segment(
|
||||
state.consumed_index = next_index
|
||||
next_index += 1
|
||||
|
||||
tts_flag_this_voice_session = bool(state.tts_this_turn)
|
||||
|
||||
for _, ordered_text, ordered_segment in ready_segments:
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
@@ -612,6 +768,7 @@ async def process_audio_segment(
|
||||
user=user,
|
||||
user_message_timestamp=ordered_segment.created_at
|
||||
or user_message_timestamp,
|
||||
tts_this_turn=tts_flag_this_voice_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -638,6 +795,48 @@ async def process_audio_segment(
|
||||
# ── 用户消息处理 ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def process_persisted_user_segment_response(
    *,
    conversation_id: str,
    user_id: str,
    segment_id: str,
    tts_this_turn: bool = False,
) -> None:
    """Generate the assistant reply for an already-persisted user segment.

    Intended to run in the background so the reply is still produced and
    stored after the WebSocket page has gone away. The function opens its own
    database session, reloads the conversation, user and segment by id, and
    hands them to ``process_user_message``. Everything runs under the
    conversation's reply lock from ``_get_user_response_lock``.

    Args:
        conversation_id: Conversation that owns the segment.
        user_id: Owner of the conversation.
        segment_id: Primary key of the stored user segment to answer.
        tts_this_turn: Forwarded unchanged to ``process_user_message``.
    """
    reply_lock = _get_user_response_lock(conversation_id)
    async with reply_lock, AsyncSessionLocal() as db:
        conversation = await db.get(Conversation, conversation_id)
        user = await db.get(User, user_id)
        segment = await db.get(Segment, segment_id)

        # Every row must exist, the conversation must be live and owned by
        # the user, and the segment must belong to that conversation.
        inputs_valid = (
            conversation
            and conversation.deleted_at is None
            and conversation.user_id == user_id
            and user
            and segment
            and segment.conversation_id == conversation_id
        )
        if not inputs_valid:
            logger.warning(
                "跳过用户回复后台任务: conversation_id={} segment_id={} user_id={}",
                conversation_id,
                segment_id,
                user_id,
            )
            return

        # Prefer the segment's own creation time; fall back to the
        # conversation's last-message time when it is missing.
        message_timestamp = segment.created_at or conversation.last_message_at
        await process_user_message(
            conversation_id=conversation_id,
            user_message=segment.user_input_text or "",
            conversation=conversation,
            segment=segment,
            db=db,
            user=user,
            user_message_timestamp=message_timestamp,
            tts_this_turn=tts_this_turn,
        )
|
||||
|
||||
|
||||
async def process_user_message(
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
@@ -648,6 +847,7 @@ async def process_user_message(
|
||||
user_message_timestamp: Optional[datetime] = None,
|
||||
*,
|
||||
force_skip_tts: bool = False,
|
||||
tts_this_turn: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent。"""
|
||||
store = ConversationHistoryStore(db)
|
||||
@@ -682,20 +882,23 @@ async def process_user_message(
|
||||
get_filled_profile_fields_fn=get_filled_profile_fields,
|
||||
),
|
||||
)
|
||||
responses = turn.messages
|
||||
skip_tts = bool(turn.skip_tts)
|
||||
want_voice = bool(tts_this_turn) if tts_this_turn is not None else False
|
||||
want_tts = want_voice and settings.enable_tts and not skip_tts
|
||||
if agent_summary_enabled():
|
||||
logger.info(
|
||||
"pipeline.process_user_message duration_ms={:.2f} "
|
||||
"conversation_id={} segment_id={} user_msg_len={} "
|
||||
"response_segments={} skip_tts={}",
|
||||
"response_segments={} skip_tts={} want_tts={}",
|
||||
(time.perf_counter() - t_pipeline) * 1000,
|
||||
conversation_id,
|
||||
segment.id,
|
||||
len(user_message or ""),
|
||||
len(turn.messages),
|
||||
turn.skip_tts,
|
||||
want_tts,
|
||||
)
|
||||
responses = turn.messages
|
||||
skip_tts = bool(turn.skip_tts)
|
||||
|
||||
segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses)
|
||||
_mark_conversation_active(conversation)
|
||||
@@ -750,6 +953,21 @@ async def process_user_message(
|
||||
tts_epoch_start = _tts_epoch_value(conversation_id)
|
||||
n = len(responses)
|
||||
for i, response_text in enumerate(responses):
|
||||
url_for_segment: Optional[str] = None
|
||||
if want_tts:
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
break
|
||||
url_for_segment = await _send_tts_audio(
|
||||
conversation_id,
|
||||
response_text,
|
||||
chunk_index=i,
|
||||
chunk_total=n,
|
||||
assistant_message_id=ai_msg_id,
|
||||
tts_epoch_start=tts_epoch_start,
|
||||
)
|
||||
if url_for_segment:
|
||||
tts_urls.append(url_for_segment)
|
||||
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
@@ -764,20 +982,7 @@ async def process_user_message(
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
url = None
|
||||
if not skip_tts:
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
break
|
||||
url = await _send_tts_audio(
|
||||
conversation_id,
|
||||
response_text,
|
||||
chunk_index=i,
|
||||
chunk_total=n,
|
||||
assistant_message_id=ai_msg_id,
|
||||
tts_epoch_start=tts_epoch_start,
|
||||
)
|
||||
if url:
|
||||
tts_urls.append(url)
|
||||
|
||||
if _tts_epoch_value(conversation_id) != tts_epoch_start:
|
||||
break
|
||||
if i < n - 1:
|
||||
|
||||
Reference in New Issue
Block a user