Replace CreateRecTask polling with recording-file flash API, add TENCENT_APP_ID, remove server-side pydub slicing, and log ASR recognition text at INFO in development. Co-authored-by: Cursor <cursoragent@cursor.com>
1215 lines
44 KiB
Python
1215 lines
44 KiB
Python
"""核心消息处理管道:Agent 调用、ASR 转写、分段有序聚合"""
|
||
|
||
import asyncio
|
||
import base64
|
||
import time
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||
|
||
from app.core.logging import get_logger
|
||
|
||
if TYPE_CHECKING:
|
||
from app.features.quota.service import QuotaService
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.agents.chat import ChatOrchestrator
|
||
from app.agents.chat.reply_limits import segments_from_llm_response
|
||
from app.core.agent_logging import agent_summary_enabled, log_asr_transcript_result
|
||
from app.core.business_telemetry import business_span
|
||
from app.core.config import settings
|
||
from app.core.cos_url_keys import (
|
||
TTS_PRESIGNED_EXPIRES_SEC,
|
||
extract_cos_object_key_if_owned,
|
||
)
|
||
from app.core.db import AsyncSessionLocal
|
||
from app.features.conversation.ws.persist import (
|
||
persist_message_tts_url_segment,
|
||
persist_voice_segment_row,
|
||
)
|
||
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
|
||
from app.features.conversation.chat_turn import (
|
||
ChatTurnContext,
|
||
ChatTurnInput,
|
||
ChatTurnService,
|
||
)
|
||
from app.features.conversation.history_store import (
|
||
AI_RESPONSE_SEGMENT_JOIN,
|
||
ConversationHistoryStore,
|
||
)
|
||
from app.features.conversation.models import Conversation, ConversationMessage, Segment
|
||
from app.features.conversation.ws.connection_manager import manager
|
||
from app.features.conversation.ws.message_types import MessageType
|
||
from app.features.conversation.ws.profile_collector import (
|
||
apply_extracted_profile,
|
||
get_filled_profile_fields,
|
||
get_missing_profile_fields,
|
||
)
|
||
from app.features.conversation.ws.topic_chips_push import maybe_send_topic_chips_ws
|
||
from app.features.memoir.background_runner import BackgroundTaskRunner
|
||
from app.features.memoir.ingest_scheduler import MemoirIngestScheduler, MemoirTrigger
|
||
from app.features.memoir.state_service import get_or_create_state
|
||
from app.features.user.models import User
|
||
from app.ports.asr import ASRTranscriptionError
|
||
from app.core.runtime_constants import tts_defaults
|
||
|
||
# Module-level logger for this pipeline module.
logger = get_logger(__name__)


# Per-conversation TTS cancel epoch, keyed by conversation_id.
# Incremented when the client sends tts_cancel; the TTS loop inside process_user_message
# and the checks before/after synthesis compare against it to short-circuit the
# remaining chunks.
_tts_cancel_epoch: dict[str, int] = {}
|
||
|
||
|
||
def bump_tts_cancel_epoch(conversation_id: str) -> None:
    """Advance the TTS cancel epoch of *conversation_id* by one.

    A conversation with no recorded epoch is treated as being at epoch 0.
    """
    previous_epoch = _tts_cancel_epoch.get(conversation_id, 0)
    _tts_cancel_epoch[conversation_id] = previous_epoch + 1
|
||
|
||
|
||
def _tts_epoch_value(conversation_id: str) -> int:
    """Return the current TTS cancel epoch for *conversation_id* (0 when none is recorded)."""
    try:
        return _tts_cancel_epoch[conversation_id]
    except KeyError:
        return 0
|
||
|
||
|
||
def _resolve_user_language(user) -> str:
|
||
"""Return 'en' iff user.language_preference is set to 'en'; default 'zh'."""
|
||
raw = getattr(user, "language_preference", "zh") if user is not None else "zh"
|
||
return "en" if str(raw or "zh").strip().lower() == "en" else "zh"
|
||
|
||
|
||
def _tts_object_ext(codec: str) -> str:
|
||
c = (codec or "mp3").lower().lstrip(".")
|
||
if c in ("wave",):
|
||
return "wav"
|
||
return c if c else "mp3"
|
||
|
||
|
||
def _tts_codec_to_content_type(codec: str) -> str:
|
||
c = (codec or "mp3").lower().lstrip(".")
|
||
if c == "mp3":
|
||
return "audio/mpeg"
|
||
if c in ("wav", "wave"):
|
||
return "audio/wav"
|
||
return "application/octet-stream"
|
||
|
||
|
||
async def _send_tts_audio(
    conversation_id: str,
    text: str,
    *,
    chunk_index: int,
    chunk_total: int,
    assistant_message_id: str | None,
    tts_epoch_start: int,
    manual: bool = False,
    language: str = "zh",
) -> str | None:
    """Synthesize one TTS chunk, upload it to object storage, and push a TTS_AUDIO message.

    Returns the URL returned by ``storage.upload`` on success. Returns ``None`` when
    automatic TTS is disabled (and the request is not manual), when the conversation's
    cancel epoch no longer equals *tts_epoch_start*, when synthesis yields no audio, or
    when any step raises (the exception is logged, not propagated).
    """
    epoch_on_entry = _tts_epoch_value(conversation_id)
    # enable_tts only disables TTS generated automatically for assistant replies;
    # a manual request (user tapped the speaker button) may still synthesize.
    auto_tts_disabled = not manual and not settings.enable_tts
    if auto_tts_disabled or epoch_on_entry != tts_epoch_start:
        return None

    try:
        provider = get_tts_provider()
        synthesized = await provider.synthesize(text, language=language)
        if not synthesized:
            logger.warning(
                "TTS skipped: synthesize returned empty conversation_id={} chunk_index={} "
                "language={} text_preview={!r} voice_provider={}",
                conversation_id,
                chunk_index,
                language,
                (text or "")[:30],
                tts_defaults.provider,
            )
            return None

        # The epoch may have been bumped while synthesis was awaited; re-check before
        # uploading or sending anything.
        if _tts_epoch_value(conversation_id) != tts_epoch_start:
            return None

        extension = _tts_object_ext(tts_defaults.codec)
        mime_type = _tts_codec_to_content_type(tts_defaults.codec)
        storage = get_object_storage()
        object_key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{extension}"
        stored_url = storage.upload(object_key, synthesized, mime_type)
        # The client receives a presigned playback URL rather than the stored URL.
        presigned_url = storage.get_url(object_key, expires=TTS_PRESIGNED_EXPIRES_SEC)

        data: Dict[str, Any] = {
            "format": tts_defaults.codec,
            "audio_base64": base64.b64encode(synthesized).decode("utf-8"),
            "audio_url": presigned_url,
            "index": chunk_index,
            "total": chunk_total,
        }
        if assistant_message_id:
            data["assistant_message_id"] = assistant_message_id
        if manual:
            data["manual"] = True

        envelope = {
            "type": MessageType.TTS_AUDIO,
            "conversation_id": conversation_id,
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await manager.send_message(conversation_id, envelope)
        return stored_url
    except Exception as exc:
        detail = str(exc)
        if "PkgExhausted" not in detail:
            logger.error("TTS synthesize failed: {}", exc)
        else:
            # Provider quota exhausted: expected operational condition, so only warn.
            logger.warning(
                "TTS skipped: 腾讯云语音合成资源包已用尽,请在控制台购买或开通后付费: {}",
                detail[:100],
            )
        return None
|
||
|
||
|
||
async def handle_tts_request_on_demand(
    *,
    conversation_id: str,
    user_id: str,
    assistant_message_id: str,
    segment_index: int,
    segment_text: str | None,
    db: AsyncSession,
) -> tuple[bool, str]:
    """Serve a manual ("speaker button") TTS request for one segment of an assistant message.

    If the segment already has a stored TTS URL it is re-sent (presigned when the URL
    resolves to an owned object key); otherwise the segment is synthesized, the URL is
    persisted on the message, and the audio is sent. The same segment is not synthesized
    twice.

    Returns ``(True, "")`` on success, or ``(False, reason)`` where *reason* is a
    user-facing message.
    """
    logger.info(
        "pipeline.handle_tts_request_on_demand entry conversation_id={} user_id={} "
        "assistant_message_id={} segment_index={} segment_text_len={} enable_tts={} provider={}",
        conversation_id,
        user_id,
        assistant_message_id,
        segment_index,
        len(segment_text or ""),
        settings.enable_tts,
        tts_defaults.provider,
    )

    # --- Validate that the conversation belongs to the user and is not deleted ---------
    conv = await db.get(Conversation, conversation_id)
    if not conv or conv.user_id != user_id or conv.deleted_at is not None:
        logger.debug(
            "pipeline.handle_tts_request_on_demand result ok=False reason=对话不存在或无权访问 "
            "conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )
        return False, "对话不存在或无权访问"

    # --- Validate that the message is an AI message of this conversation ---------------
    msg = await db.get(ConversationMessage, assistant_message_id)
    if not msg or msg.conversation_id != conversation_id or msg.role != "ai":
        logger.debug(
            "pipeline.handle_tts_request_on_demand result ok=False reason=消息不存在 "
            "conversation_id={} assistant_message_id={}",
            conversation_id,
            assistant_message_id,
        )
        return False, "消息不存在"

    # Split the same way as the client (splitMessageParts / segments_from_llm_response),
    # including paragraph splitting when there is no [SPLIT] marker.
    parts = segments_from_llm_response(msg.content or "", max_segments=3)
    chunk_total = len(parts)
    if not 0 <= segment_index < chunk_total:
        return False, "分段序号无效"

    canon = (parts[segment_index] or "").strip()
    if not canon:
        return False, "该段无朗读文本"

    client_text = (segment_text or "").strip()
    if client_text and client_text != canon:
        # The server-side canonical text wins; the mismatch is only logged.
        logger.debug(
            "按需 TTS: 客户端传入 segment_text 与规范化后 canon 不完全一致,已按 segment_index 朗读 canon "
            "(client_len={} canon_len={})",
            len(client_text),
            len(canon),
        )

    # Normalize stored URLs: non-string / blank entries become "", padded to one per part.
    stored_urls: List[str] = [
        entry if isinstance(entry, str) and entry.strip() else ""
        for entry in (msg.tts_audio_urls or [])
    ]
    missing = chunk_total - len(stored_urls)
    if missing > 0:
        stored_urls.extend([""] * missing)
    existing = stored_urls[segment_index].strip()

    if existing:
        logger.info(
            "pipeline.handle_tts_request_on_demand reuse existing url conversation_id={} "
            "assistant_message_id={} segment_index={} url_len={}",
            conversation_id,
            assistant_message_id,
            segment_index,
            len(existing),
        )
        storage = get_object_storage()
        key = extract_cos_object_key_if_owned(existing)
        try:
            if key:
                playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
            else:
                playback_url = existing
        except Exception as exc:
            # Presigning is best-effort; fall back to the stored URL.
            logger.warning("按需 TTS 预签名失败: {}", exc)
            playback_url = existing

        reuse_message = {
            "type": MessageType.TTS_AUDIO,
            "conversation_id": conversation_id,
            "data": {
                "audio_url": playback_url,
                "format": tts_defaults.codec,
                "index": segment_index,
                "total": chunk_total,
                "assistant_message_id": assistant_message_id,
                "manual": True,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await manager.send_message(conversation_id, reuse_message)
        logger.info(
            "pipeline.handle_tts_request_on_demand result ok=True reason=existing_reused "
            "conversation_id={} assistant_message_id={} segment_index={}",
            conversation_id,
            assistant_message_id,
            segment_index,
        )
        return True, ""

    logger.info(
        "pipeline.handle_tts_request_on_demand no existing url, will synthesize "
        "conversation_id={} assistant_message_id={} segment_index={} canon_len={}",
        conversation_id,
        assistant_message_id,
        segment_index,
        len(canon),
    )

    user_obj = await db.get(User, user_id)
    user_language = _resolve_user_language(user_obj)

    url_stored = await _send_tts_audio(
        conversation_id,
        canon,
        chunk_index=segment_index,
        chunk_total=chunk_total,
        assistant_message_id=assistant_message_id,
        tts_epoch_start=_tts_epoch_value(conversation_id),
        manual=True,
        language=user_language,
    )
    logger.info(
        "pipeline.handle_tts_request_on_demand _send_tts_audio returned url_stored_set={} "
        "conversation_id={} assistant_message_id={} segment_index={}",
        bool(url_stored),
        conversation_id,
        assistant_message_id,
        segment_index,
    )
    if not url_stored:
        logger.info(
            "pipeline.handle_tts_request_on_demand result ok=False reason=语音合成失败 "
            "conversation_id={} assistant_message_id={} segment_index={}",
            conversation_id,
            assistant_message_id,
            segment_index,
        )
        return False, "语音合成失败"

    await persist_message_tts_url_segment(db, msg, segment_index, url_stored)

    history_store = ConversationHistoryStore(db)
    await history_store._sync_redis_best_effort(conversation_id)
    logger.info(
        "pipeline.handle_tts_request_on_demand result ok=True reason=synthesized "
        "conversation_id={} assistant_message_id={} segment_index={}",
        conversation_id,
        assistant_message_id,
        segment_index,
    )
    return True, ""
|
||
|
||
|
||
# ── Agent instances (moved out of ConnectionManager) ─────────────────────
# Created once at import time and used as module-level singletons below.
chat_orchestrator = ChatOrchestrator()
# Turn service wrapping the orchestrator; used by _process_user_message_inner.
chat_turn_service = ChatTurnService(chat_orchestrator)
# Runner handed to the memoir ingest scheduler.
_background_runner = BackgroundTaskRunner()
memoir_ingest_scheduler = MemoirIngestScheduler(_background_runner)
|
||
|
||
|
||
async def _schedule_memoir_ingest_for_segment(
    user_id: str,
    segment: Segment,
    *,
    trigger: MemoirTrigger = "turn",
) -> None:
    """Queue *segment* on the memoir ingest scheduler when it carries non-blank user text.

    Segments whose ``user_input_text`` is missing or whitespace-only are ignored.
    """
    stripped_text = (segment.user_input_text or "").strip()
    if stripped_text:
        await memoir_ingest_scheduler.queue_segment(
            user_id,
            str(segment.id),
            text_char_count=len(stripped_text),
            trigger=trigger,
        )
|
||
|
||
|
||
# ── 分段流状态 ──────────────────────────────────────────────────
|
||
|
||
|
||
@dataclass
class SegmentStreamState:
    """Per voice-session segment processing state (for parallel ASR + ordered aggregation)."""

    # Held by process_audio_segment / _delayed_listening_feedback while they read or
    # update the bookkeeping fields below.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    #: "Read aloud this turn" switch carried by the most recent uploaded segment of this
    #: voice session (the client is expected to send the same value on every segment).
    tts_this_turn: bool = False
    # Segment indices discarded in process_audio_segment's ``finally``; they are not
    # added anywhere in the visible code — presumably the caller adds them (verify).
    pending_indices: Set[int] = field(default_factory=set)
    # Segment indices that were persisted or found already persisted.
    processed_indices: Set[int] = field(default_factory=set)
    # segment_index -> (transcript text, Segment row) waiting to be consumed in order.
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest segment index already handed over for in-order processing; -1 means none.
    consumed_index: int = -1
    # Tasks registered through register_segment_task that have not finished yet.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    # True once the one-off listening feedback has been sent for this voice session.
    listening_feedback_sent: bool = False
    # Handle for the delayed listening-feedback task; reset to None by that task when
    # it fires.
    listening_feedback_task: Optional[asyncio.Task] = None
|
||
|
||
|
||
# Segment stream state keyed by (conversation_id, voice_session_id).
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}
# Background user-response tasks still running, keyed by conversation_id.
_user_response_tasks: Dict[str, Set[asyncio.Task]] = {}
# Lazily created per-conversation locks (see _get_user_response_lock), keyed by
# conversation_id.
_user_response_locks: Dict[str, asyncio.Lock] = {}
|
||
|
||
|
||
def _get_user_response_lock(conversation_id: str) -> asyncio.Lock:
    """Return the response lock for *conversation_id*, creating and caching it on first use."""
    try:
        return _user_response_locks[conversation_id]
    except KeyError:
        created = asyncio.Lock()
        _user_response_locks[conversation_id] = created
        return created
|
||
|
||
|
||
def register_user_response_task(conversation_id: str, task: asyncio.Task) -> None:
    """Track a background user-response *task* for *conversation_id*.

    When the task finishes it is removed from tracking. Once no tracked task remains for
    the conversation, both the task registry entry and the conversation's response lock
    are dropped. Cancellation is logged as a warning and a task exception as an error.
    """
    tracked = _user_response_tasks.setdefault(conversation_id, set())
    tracked.add(task)

    def _on_done(finished: asyncio.Task) -> None:
        tracked.discard(finished)
        if len(tracked) == 0:
            # Last task for this conversation: release the per-conversation bookkeeping.
            for registry in (_user_response_tasks, _user_response_locks):
                registry.pop(conversation_id, None)

        if finished.cancelled():
            logger.warning(
                "用户回复后台任务被取消 conversation_id={}",
                conversation_id,
            )
            return

        failure = finished.exception()
        if failure:
            logger.error(
                "用户回复后台任务异常 conversation_id={}: {}",
                conversation_id,
                failure,
                exc_info=True,
            )

    task.add_done_callback(_on_done)
|
||
|
||
|
||
def get_or_create_segment_state(
    conversation_id: str,
    voice_session_id: str,
) -> SegmentStreamState:
    """Return the stream state for (conversation_id, voice_session_id), creating it if absent."""
    key = (conversation_id, voice_session_id)
    current = _segment_states.get(key)
    if current is None:
        current = SegmentStreamState()
        _segment_states[key] = current
    return current
|
||
|
||
|
||
def register_segment_task(
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track a segment-processing *task* on its voice session's stream state.

    When the task finishes it is removed from ``state.active_tasks``. If no task remains
    and the conversation has no active connection, the stream state is dropped from the
    module registry. A task exception is logged as an error; cancellation is silent.
    """
    state_key = (conversation_id, voice_session_id)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        if not state.active_tasks and conversation_id not in manager.active_connections:
            _segment_states.pop(state_key, None)
        if done_task.cancelled():
            return
        exc = done_task.exception()
        if exc:
            # Pass values as `{}` placeholders, like every other log call in this module.
            # Pre-rendering them with an f-string while also passing keyword arguments
            # lets a format-style logger re-interpret braces in the exception text.
            logger.error(
                "分段处理任务异常 (conversation_id={}, voice_session_id={}): {}",
                conversation_id,
                voice_session_id,
                exc,
                exc_info=True,
            )

    task.add_done_callback(_cleanup)
|
||
|
||
|
||
def cleanup_segment_states(conversation_id: str) -> None:
    """Drop this conversation's segment states that have no active tasks (run after disconnect)."""
    # Iterate over a snapshot of the keys so entries can be removed while looping.
    for key in list(_segment_states):
        if key[0] != conversation_id:
            continue
        if _segment_states[key].active_tasks:
            continue
        _segment_states.pop(key, None)
|
||
|
||
|
||
# ── 工具函数 ────────────────────────────────────────────────────
|
||
|
||
|
||
def _utc_now() -> datetime:
|
||
return datetime.now(timezone.utc)
|
||
|
||
|
||
def _voice_session_id_from_client_segment_id(
|
||
client_segment_id: Optional[str],
|
||
) -> Optional[str]:
|
||
if not client_segment_id:
|
||
return None
|
||
session_id, separator, _ = client_segment_id.rpartition("-")
|
||
if separator and session_id:
|
||
return session_id
|
||
return None
|
||
|
||
|
||
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
|
||
"""构建分段语音的幂等标识(conversation_id + voice_session_id + segment_index)。"""
|
||
return f"audio-segment:{voice_session_id}:{segment_index}"
|
||
|
||
|
||
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
|
||
"""从 audio_url 解析 voice_session_id 与 segment_index(audio-segment:{session_id}:{index})。"""
|
||
prefix = "audio-segment:"
|
||
if not audio_url or not audio_url.startswith(prefix):
|
||
return None
|
||
payload = audio_url[len(prefix) :]
|
||
voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
|
||
if not separator:
|
||
return None
|
||
try:
|
||
sid = str(voice_session_id_raw).strip()
|
||
if not sid:
|
||
return None
|
||
return (sid, int(segment_index_raw))
|
||
except ValueError:
|
||
return None
|
||
|
||
|
||
def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in *audio_url*, or ``None`` if it cannot be parsed."""
    parsed = _extract_segment_scope(audio_url)
    return parsed[0] if parsed else None
|
||
|
||
|
||
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
|
||
if not transcript_text:
|
||
return True
|
||
return transcript_text.startswith("转写失败")
|
||
|
||
|
||
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Return the newest stored segment for this voice session and index, or ``None``.

    Segments are matched on the conversation id plus the idempotency marker produced by
    ``_build_segment_audio_url``; rows are ordered by ``created_at`` descending.
    """
    marker = _build_segment_audio_url(voice_session_id, segment_index)
    query = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == marker,
        )
        .order_by(Segment.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()
    # The rows are re-checked in Python before one is returned.
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id and row.audio_url == marker
        ),
        None,
    )
|
||
|
||
|
||
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest index N such that segments 0..N of this voice session are stored.

    Used to restore the in-memory consumed index after a reconnect. Returns -1 when
    segment 0 is not stored.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()

    stored_indices: Set[int] = set()
    for row in rows:
        if row.conversation_id != conversation_id:
            continue
        scope = _extract_segment_scope(row.audio_url)
        if scope and scope[0] == voice_session_id:
            stored_indices.add(scope[1])

    # Walk upwards from 0 until the first gap; a contiguous run can never be longer
    # than the number of stored indices.
    highest_contiguous = -1
    for candidate in range(len(stored_indices)):
        if candidate not in stored_indices:
            break
        highest_contiguous = candidate
    return highest_contiguous
|
||
|
||
|
||
# ── 过渡反馈 ────────────────────────────────────────────────────
|
||
|
||
# Seconds _delayed_listening_feedback sleeps before sending the listening feedback.
LISTENING_FEEDBACK_DELAY_SEC = 5.0
# User-facing text of the listening feedback (sent as an AGENT_RESPONSE with transition=True).
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
|
||
|
||
|
||
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
) -> None:
    """Send the "I'm listening" transition feedback once (called by the delayed task).

    The message is an AGENT_RESPONSE flagged with ``transition=True`` and tagged with
    *segment_index*.
    """
    feedback_data = {
        "text": LISTENING_FEEDBACK_TEXT,
        "transition": True,
        "segment_index": segment_index,
    }
    envelope = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": feedback_data,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, envelope)
|
||
|
||
|
||
async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """Send the listening feedback once, LISTENING_FEEDBACK_DELAY_SEC seconds after starting.

    The feedback is sent at most once per voice session (guarded by
    ``state.listening_feedback_sent``). This coroutine does not itself check whether the
    user has finished recording — presumably the caller cancels the task in that case
    (verify against the caller).
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        if state.listening_feedback_sent:
            return
        # Mark as sent and clear the task handle before sending.
        state.listening_feedback_sent = True
        state.listening_feedback_task = None
    # NOTE(review): the send is assumed to happen after the lock is released — confirm
    # against the original file's indentation.
    await _send_segment_transition_feedback(conversation_id, 0)
|
||
|
||
|
||
# ── 分段语音异步处理 ────────────────────────────────────────────
|
||
|
||
|
||
async def process_audio_segment(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
    *,
    tts_this_turn: bool = False,
) -> None:
    """Process one uploaded voice segment: ASR, idempotent persistence, ordered hand-off.

    Steps:
      1. Record the turn's TTS switch on the voice session state.
      2. Load the conversation and user; abort with an ERROR message if either is missing.
      3. On first use of the state, restore ``consumed_index`` from already stored
         segments (reconnect recovery).
      4. Decode the audio, transcribe it, and send a TRANSCRIPT message.
      5. On transcription failure send an ERROR message and stop.
      6. Skip segments already stored for this (voice session, index); otherwise store
         a new ``Segment`` row.
      7. Buffer the transcript and pass every segment that is now contiguous with
         ``consumed_index`` to ``process_user_message`` in index order.

    Any unexpected exception is logged and reported to the client as an ERROR message.
    The segment index is always removed from ``state.pending_indices`` on exit.
    """
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        state.tts_this_turn = bool(tts_this_turn)
    logger.info(
        "process_audio_segment 开始: conversation_id={} voice_session_id={} "
        "segment_index={} is_last={} duration_s={} audio_b64_len={}",
        conversation_id,
        voice_session_id,
        segment_index,
        is_last,
        audio_duration,
        len(audio_base64 or ""),
    )

    try:
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            if not conversation or conversation.deleted_at is not None:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "对话不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            if not user:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "用户不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return

            # A state that has consumed, processed and buffered nothing is "fresh";
            # only then is the consumed index restored from the database.
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )

            if should_prime_state:
                persisted_contiguous_index = (
                    await _get_persisted_contiguous_segment_index(
                        db=db,
                        conversation_id=conversation_id,
                        voice_session_id=voice_session_id,
                    )
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        # max(): another task may have advanced the index while the
                        # query was awaited.
                        state.consumed_index = max(
                            state.consumed_index, persisted_contiguous_index
                        )

            try:
                audio_bytes = base64.b64decode(audio_base64)
            except Exception:
                # Undecodable payloads are treated like empty audio.
                audio_bytes = b""
            if not audio_bytes:
                # Only logged: the (empty) audio is still passed to the ASR provider.
                logger.warning(
                    "process_audio_segment: 解码后音频为空 conversation_id={} segment_index={}",
                    conversation_id,
                    segment_index,
                )
            try:
                asr = get_asr_provider()
                transcript_text = await asr.transcribe(audio_bytes, format="m4a")
                if transcript_text:
                    log_asr_transcript_result(
                        logger,
                        text=transcript_text,
                        conversation_id=conversation_id,
                        voice_session_id=voice_session_id,
                        segment_index=segment_index,
                        duration_s=audio_duration,
                        audio_len=len(audio_bytes),
                        source="audio_segment",
                    )
            except ASRTranscriptionError as e:
                logger.warning(
                    "ASR 转写失败 segment_index={} conversation_id={}: {}",
                    segment_index,
                    conversation_id,
                    e,
                )
                transcript_text = ""
            # The transcript (possibly empty) is always echoed to the client.
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.TRANSCRIPT,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": transcript_text or "",
                        "audio_duration": audio_duration,
                        "voice_session_id": voice_session_id,
                        "segment_index": segment_index,
                        "is_last": is_last,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

            if _is_transcribe_failure(transcript_text):
                detail = (transcript_text or "").strip()
                if not detail:
                    user_msg = f"分段 {segment_index} 未识别到语音内容,请重试或检查麦克风与网络"
                else:
                    user_msg = f"分段 {segment_index} 语音识别失败,请稍后再试"
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            "message": user_msg,
                            "segment_index": segment_index,
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                # NOTE(review): this path returns without buffering the segment or
                # advancing consumed_index, so later indices of the same voice session
                # stay buffered until this index is processed successfully — confirm
                # that the client retries the failed index.
                return

            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                # Idempotency: the same (voice session, index) was already stored.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.debug(
                    "分段已存在,按幂等跳过: conversation_id={} voice_session_id={} "
                    "segment_index={} segment_id={} transcript={}",
                    conversation_id,
                    voice_session_id,
                    segment_index,
                    existing_segment.id,
                    existing_segment.user_input_text or "",
                )
                return

            segment = Segment(
                id=str(uuid.uuid4()),
                conversation_id=conversation_id,
                user_input_text=transcript_text or "",
                audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                audio_duration_seconds=audio_duration
                if audio_duration > 0
                else None,
                processed=False,
            )
            await persist_voice_segment_row(db, segment, conversation)
            user_message_timestamp = conversation.last_message_at
            await db.refresh(segment)

            ready_segments: List[Tuple[int, str, Segment]] = []
            tts_flag_this_voice_session = False
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (
                    transcript_text or "",
                    segment,
                )

                # Drain every buffered transcript that directly follows consumed_index.
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1

                tts_flag_this_voice_session = bool(state.tts_this_turn)

            # Processed outside the lock so other segments can keep buffering.
            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at
                    or user_message_timestamp,
                    tts_this_turn=tts_flag_this_voice_session,
                )

    except Exception as e:
        # Pass values as `{}` placeholders, like the rest of this module. Pre-rendering
        # them with an f-string while also passing keyword arguments lets a format-style
        # logger re-interpret braces in the exception text, which could raise here and
        # prevent the ERROR message below from being sent.
        logger.error(
            "处理语音分段失败: conversation_id={}, segment_index={}, error={}",
            conversation_id,
            segment_index,
            e,
            exc_info=True,
        )
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.ERROR,
                "data": {
                    "message": "语音分段处理遇到问题,请重试",
                    "segment_index": segment_index,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
    finally:
        async with state.lock:
            state.pending_indices.discard(segment_index)
|
||
|
||
|
||
# ── 用户消息处理 ────────────────────────────────────────────────
|
||
|
||
|
||
async def process_persisted_user_segment_response(
    *,
    conversation_id: str,
    user_id: str,
    segment_id: str,
    tts_this_turn: bool = False,
) -> None:
    """Generate the assistant reply for an already stored user segment, in the background.

    Runs under the conversation's response lock with its own database session, so the
    reply can still be completed and stored after the WebSocket page has gone away.
    The task is skipped (with a warning) when the conversation, user or segment cannot
    be loaded, the conversation is deleted or owned by another user, or the segment
    belongs to a different conversation.
    """
    async with _get_user_response_lock(conversation_id), AsyncSessionLocal() as db:
        conversation = await db.get(Conversation, conversation_id)
        user = await db.get(User, user_id)
        segment = await db.get(Segment, segment_id)

        request_is_valid = (
            conversation
            and conversation.deleted_at is None
            and conversation.user_id == user_id
            and user
            and segment
            and segment.conversation_id == conversation_id
        )
        if not request_is_valid:
            logger.warning(
                "跳过用户回复后台任务: conversation_id={} segment_id={} user_id={}",
                conversation_id,
                segment_id,
                user_id,
            )
            return

        # Prefer the segment's own creation time as the user message timestamp.
        message_timestamp = segment.created_at or conversation.last_message_at
        await process_user_message(
            conversation_id=conversation_id,
            user_message=segment.user_input_text or "",
            conversation=conversation,
            segment=segment,
            db=db,
            user=user,
            user_message_timestamp=message_timestamp,
            tts_this_turn=tts_this_turn,
        )
|
||
|
||
|
||
async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: User = None,
    user_message_timestamp: Optional[datetime] = None,
    *,
    force_skip_tts: bool = False,
    tts_this_turn: Optional[bool] = None,
    memoir_trigger: MemoirTrigger = "turn",
    schedule_memoir: bool = True,
) -> None:
    """Handle one user message and produce the agent reply.

    Thin wrapper: runs ``_process_user_message_inner`` with the same arguments inside
    the ``conversation.ws.process_turn`` business span.
    """
    with business_span("conversation.ws.process_turn"):
        await _process_user_message_inner(
            conversation_id=conversation_id,
            user_message=user_message,
            conversation=conversation,
            segment=segment,
            db=db,
            user=user,
            user_message_timestamp=user_message_timestamp,
            force_skip_tts=force_skip_tts,
            tts_this_turn=tts_this_turn,
            memoir_trigger=memoir_trigger,
            schedule_memoir=schedule_memoir,
        )
|
||
|
||
|
||
async def _send_turn_error_if_connected(conversation_id: str, message: str) -> None:
    """Send an ERROR frame for ``conversation_id`` if it is in ``manager.active_connections``.

    Does nothing when the conversation has no registered connection. Errors
    raised by ``manager.send_message`` propagate to the caller.
    """
    if conversation_id not in manager.active_connections:
        return
    await manager.send_message(
        conversation_id,
        {
            "type": MessageType.ERROR,
            "data": {"message": message},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
    )


async def _process_user_message_inner(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: Optional[User] = None,
    user_message_timestamp: Optional[datetime] = None,
    *,
    force_skip_tts: bool = False,
    tts_this_turn: Optional[bool] = None,
    memoir_trigger: MemoirTrigger = "turn",
    schedule_memoir: bool = True,
) -> None:
    """Run one chat turn: generate the reply, persist it, stream it, attach TTS.

    Steps, in order:

    1. Call ``chat_turn_service.process_turn`` to obtain the reply segments.
    2. Persist the human/AI turn via ``record_human_ai_turn_with_segment``.
    3. If nothing was persisted (no turn ids), notify the client with an ERROR
       frame when it is still connected.
    4. Schedule memoir ingest for the segment (unless ``schedule_memoir`` is
       false); this happens whether or not a reply was produced.
    5. For each reply segment, optionally synthesize TTS and always send an
       AGENT_RESPONSE frame, pausing 0.5 s between segments.
    6. Best-effort topic chips, then persist the collected TTS URLs.

    Any exception is caught here: already-collected TTS URLs are persisted on a
    best-effort basis, the failure is logged with its traceback, and an ERROR
    frame is sent if the client is still connected. Nothing is re-raised.

    Args:
        conversation_id: ID of the conversation being processed.
        user_message: Text of the user's message.
        conversation: Loaded conversation row.
        segment: Segment row for the user's input; its ``audio_url`` decides
            whether the turn counts as voice input.
        db: Async database session used for the history store and memoir state.
        user: Conversation owner; may be ``None`` (topic chips are then skipped
            and ``conversation.user_id`` is used as the memoir owner).
        user_message_timestamp: Timestamp recorded for the user message.
        force_skip_tts: Forwarded to the chat turn input.
        tts_this_turn: Caller's request for voice output; ``None`` means no.
        memoir_trigger: Trigger label passed to memoir ingest scheduling.
        schedule_memoir: When ``False``, memoir ingest is not scheduled.
    """
    store = ConversationHistoryStore(db)
    tts_urls: list[str] = []
    user_language = _resolve_user_language(user)
    try:
        logger.info(
            "process_user_message 开始: conversation_id={} segment_id={} user_chars={}",
            conversation_id,
            segment.id,
            len(user_message or ""),
        )
        is_from_voice = bool(segment.audio_url)
        voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
        audio_dur = getattr(segment, "audio_duration_seconds", None)
        t_pipeline = time.perf_counter()
        turn = await chat_turn_service.process_turn(
            ChatTurnInput(
                conversation_id=conversation_id,
                user_message=user_message,
                is_from_voice=is_from_voice,
                voice_session_id=voice_session_id,
                user_message_timestamp=user_message_timestamp,
                audio_duration_seconds=audio_dur,
                force_skip_tts=force_skip_tts,
            ),
            ChatTurnContext(
                db=db,
                user=user,
                conversation=conversation,
                apply_extracted_profile_fn=apply_extracted_profile,
                get_missing_profile_fields_fn=get_missing_profile_fields,
                get_filled_profile_fields_fn=get_filled_profile_fields,
            ),
        )
        responses = turn.messages
        skip_tts = bool(turn.skip_tts)
        # bool(None) is False, so an unset tts_this_turn means "no voice".
        want_voice = bool(tts_this_turn)
        want_tts = want_voice and settings.enable_tts and not skip_tts
        if agent_summary_enabled():
            logger.info(
                "pipeline.process_user_message duration_ms={:.2f} "
                "conversation_id={} segment_id={} user_msg_len={} "
                "response_segments={} skip_tts={} want_tts={}",
                (time.perf_counter() - t_pipeline) * 1000,
                conversation_id,
                segment.id,
                len(user_message or ""),
                len(turn.messages),
                turn.skip_tts,
                want_tts,
            )

        agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses)
        turn_ids = await store.record_human_ai_turn_with_segment(
            conversation_id=conversation_id,
            user_message=user_message,
            responses=responses,
            segment=segment,
            user_message_timestamp=user_message_timestamp,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            audio_duration_seconds=audio_dur,
            agent_response=agent_response,
            memory_retrieval_trace=turn.memory_retrieval_trace,
        )
        if not turn_ids:
            logger.warning(
                "process_user_message: 无有效助手段落(responses 为空),conversation_id={} segment_id={}",
                conversation_id,
                segment.id,
            )
            await _send_turn_error_if_connected(
                conversation_id,
                "未生成回复,请重试或稍后再试",
            )

        # Memoir ingest is scheduled whether or not the assistant produced a
        # reply; on the empty-reply path it runs after the client was notified.
        owner_id = (user.id if user is not None else None) or conversation.user_id
        if schedule_memoir:
            await _schedule_memoir_ingest_for_segment(
                owner_id,
                segment,
                trigger=memoir_trigger,
            )

        if not turn_ids:
            return

        ai_msg_id = turn_ids.assistant_message_id
        tts_epoch_start = _tts_epoch_value(conversation_id)
        n = len(responses)
        # tts_cancelled only skips further TTS synthesis; AGENT_RESPONSE must
        # still be sent for every segment, otherwise the front end stays stuck
        # on "正在回复…" or loses the trailing text.
        tts_cancelled = False
        for i, response_text in enumerate(responses):
            url_for_segment: Optional[str] = None
            if want_tts and not tts_cancelled:
                if _tts_epoch_value(conversation_id) != tts_epoch_start:
                    # The TTS epoch moved since this turn started: stop synthesizing.
                    tts_cancelled = True
                    logger.info(
                        "pipeline.process_user_message segment={}/{} tts_branch=skip_cancelled "
                        "tts_cancelled={} conversation_id={}",
                        i,
                        n,
                        tts_cancelled,
                        conversation_id,
                    )
                else:
                    logger.info(
                        "pipeline.process_user_message segment={}/{} tts_branch=synthesize "
                        "tts_cancelled={} conversation_id={}",
                        i,
                        n,
                        tts_cancelled,
                        conversation_id,
                    )
                    url_for_segment = await _send_tts_audio(
                        conversation_id,
                        response_text,
                        chunk_index=i,
                        chunk_total=n,
                        assistant_message_id=ai_msg_id,
                        tts_epoch_start=tts_epoch_start,
                        language=user_language,
                    )
                    if url_for_segment:
                        tts_urls.append(url_for_segment)
                    # Re-check after the await: the epoch may have changed
                    # while synthesis was in flight.
                    if _tts_epoch_value(conversation_id) != tts_epoch_start:
                        tts_cancelled = True
            else:
                logger.info(
                    "pipeline.process_user_message segment={}/{} tts_branch={} "
                    "tts_cancelled={} want_tts={} conversation_id={}",
                    i,
                    n,
                    "skip_cancelled" if tts_cancelled else "skip_no_tts",
                    tts_cancelled,
                    want_tts,
                    conversation_id,
                )

            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": response_text,
                        "index": i,
                        "total": n,
                        "assistant_message_id": ai_msg_id,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

            # Pause between segments, but not after the last one.
            if i < n - 1:
                await asyncio.sleep(0.5)

        if user is not None:
            # Topic chips are best-effort: a failure here must not fail the turn.
            try:
                fresh_memoir = await get_or_create_state(user.id, db)
                await maybe_send_topic_chips_ws(
                    conversation_id,
                    user=user,
                    memoir_state=fresh_memoir,
                    reason="after_assistant_turn",
                    language=user_language,
                )
            except Exception as chip_err:
                logger.warning("after-turn topic chips skipped: {}", chip_err)

        if tts_urls:
            await store.attach_ai_tts_for_turn(
                conversation_id,
                tts_audio_urls=tts_urls,
                segment=segment,
            )

    except Exception as e:
        # Persist whatever TTS URLs were produced before the failure so they
        # are not lost; this is itself best-effort.
        if tts_urls:
            try:
                await store.attach_ai_tts_for_turn(
                    conversation_id,
                    tts_audio_urls=tts_urls,
                    segment=segment,
                )
            except Exception as persist_error:
                logger.warning("补写 TTS 元数据失败: {}", persist_error)
        logger.exception("处理用户消息失败: {}", e)
        try:
            await _send_turn_error_if_connected(
                conversation_id,
                "生成回应时遇到问题,请稍后再试",
            )
        except Exception as send_error:
            logger.warning("发送错误消息失败: {}", send_error)
|
||
|
||
|
||
# ── 对话结束处理 ────────────────────────────────────────────────
|
||
|
||
|
||
async def process_conversation_segments(
    conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
):
    """Hand this conversation's still-unclassified segments to the memoir pipeline.

    Called when a conversation ends. Segments of the conversation with
    ``processed == False`` and ``topic_category IS NULL`` are looked up and
    their IDs are passed to ``memoir_ingest_scheduler.flush_pending`` as
    ``extra_segment_ids``. Per the original notes (scheduler internals are not
    visible here), the scheduler merges and de-duplicates those IDs with its
    in-memory debounce batch, submits ``process_memoir_phase1`` a single time,
    and triggers pending Phase 2 dispatch at the end of the flush, so the
    conversation-end path and the debounce flush do not both submit Phase 1.

    The quota check goes through the injected ``quota_service`` rather than
    importing quota internals directly.

    Behavior by case:
        * Conversation missing or soft-deleted: return without doing anything.
        * No matching segments: flush the pending batch without extra IDs.
        * User found but quota check fails: log, flush without extra IDs.
        * Otherwise: flush with the segment IDs; a failure of that flush is
          logged with its traceback and not re-raised.

    Args:
        conversation_id: ID of the conversation that ended.
        db: Async database session used for the lookups.
        quota_service: Service providing ``check_can_submit_organize``.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation or conversation.deleted_at is not None:
        return

    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        # SQLAlchemy column expression, not a Python truth test.
        Segment.processed == False,  # noqa: E712
        Segment.topic_category.is_(None),
    )
    result = await db.execute(stmt)
    segments = result.scalars().all()

    if not segments:
        await memoir_ingest_scheduler.flush_pending(
            conversation.user_id,
            trigger="conversation_end",
        )
        return

    user = await db.get(User, conversation.user_id)
    if user:
        can_submit, _ = await quota_service.check_can_submit_organize(
            user.id, user.subscription_type
        )
        if not can_submit:
            logger.info(
                "用户 {} 章节配额已用尽,跳过提交整理任务: conversation_id={}",
                user.id,
                conversation_id,
            )
            # Still flush whatever is already pending; only the extra
            # segments found above are withheld.
            await memoir_ingest_scheduler.flush_pending(
                conversation.user_id,
                trigger="conversation_end",
            )
            return

    segment_ids = [seg.id for seg in segments]
    try:
        await memoir_ingest_scheduler.flush_pending(
            conversation.user_id,
            extra_segment_ids=segment_ids,
            trigger="conversation_end",
        )
        logger.info(
            "对话结束,合并批内 segment 与 DB 待分类段,单次提交 Phase1: "
            "conversation_id={} segments={}",
            conversation_id,
            len(segment_ids),
        )
    except Exception as e:
        # logger.exception keeps the traceback, matching the error handling
        # used elsewhere in this module.
        logger.exception("提交 Celery 任务失败: {}", e)
|