Files
life-echo/api/app/features/conversation/ws/pipeline.py
Kevin 93be60f74c fix(tts): gate auto reply by ENABLE_TTS; allow on-demand and manual playback
- Pipeline: skip _send_tts_audio only for non-manual when ENABLE_TTS=false;
  remove enable_tts early return from handle_tts_request_on_demand.
- Tencent TTS: PrimaryLanguage/chunking follow user language preference only.
- Expo: let manual tts_audio bypass late-segment playback gate after interrupt.
- Docs: clarify ENABLE_TTS vs tts_request in api/.env.example and TTSProvider port.
- Tests: add manual bypass cases; adjust pipeline language tests for en+Chinese text.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-11 17:15:02 +08:00

1334 lines
49 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""核心消息处理管道Agent 调用、ASR 转写、分段有序聚合"""
import asyncio
import base64
import io
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from app.core.logging import get_logger
if TYPE_CHECKING:
from app.features.quota.service import QuotaService
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat import ChatOrchestrator
from app.agents.chat.reply_limits import segments_from_llm_response
from app.core.agent_logging import agent_summary_enabled
from app.core.config import settings
from app.core.cos_url_keys import (
TTS_PRESIGNED_EXPIRES_SEC,
extract_cos_object_key_if_owned,
)
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
from app.features.conversation.chat_turn import (
ChatTurnContext,
ChatTurnInput,
ChatTurnService,
)
from app.features.conversation.history_store import (
AI_RESPONSE_SEGMENT_JOIN,
ConversationHistoryStore,
)
from app.features.conversation.lineage_schemas import DialogueLineage
from app.features.conversation.models import Conversation, ConversationMessage, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.profile_collector import (
apply_extracted_profile,
get_filled_profile_fields,
get_missing_profile_fields,
)
from app.features.memoir.background_runner import BackgroundTaskRunner
from app.features.memoir.ingest_scheduler import MemoirIngestScheduler
from app.features.user.models import User
from app.ports.asr import ASRTranscriptionError
logger = get_logger(__name__)
# Incremented whenever the client sends tts_cancel. The TTS loop in
# process_user_message and the checks before/after synthesis compare against it
# to short-circuit the remaining segments. Keyed by conversation_id.
_tts_cancel_epoch: dict[str, int] = {}
def bump_tts_cancel_epoch(conversation_id: str) -> None:
    """Advance the TTS cancel epoch of ``conversation_id`` by one (a missing entry counts as 0)."""
    previous = _tts_cancel_epoch.get(conversation_id, 0)
    _tts_cancel_epoch[conversation_id] = previous + 1
def _tts_epoch_value(conversation_id: str) -> int:
    """Return the current TTS cancel epoch of ``conversation_id``; 0 when it was never bumped."""
    try:
        return _tts_cancel_epoch[conversation_id]
    except KeyError:
        return 0
def _resolve_user_language(user) -> str:
    """Map a user's ``language_preference`` to ``'en'`` or ``'zh'``.

    Only a preference that equals ``'en'`` after trimming and lower-casing yields
    ``'en'``; a missing user, a missing/empty preference or any other value yields ``'zh'``.
    """
    if user is None:
        return "zh"
    preference = getattr(user, "language_preference", "zh") or "zh"
    if str(preference).strip().lower() == "en":
        return "en"
    return "zh"
def _tts_object_ext(codec: str) -> str:
    """Return the object-key file extension for a TTS codec name.

    The codec is lower-cased and stripped of leading dots; ``wave`` maps to ``wav``
    and an empty result falls back to ``mp3``.
    """
    normalized = (codec or "mp3").lower().lstrip(".")
    if normalized == "wave":
        return "wav"
    return normalized or "mp3"
def _tts_codec_to_content_type(codec: str) -> str:
    """Return the HTTP content type for a TTS codec name.

    ``mp3`` maps to ``audio/mpeg``, ``wav``/``wave`` to ``audio/wav``; anything else
    is reported as ``application/octet-stream``. An empty codec is treated as ``mp3``.
    """
    normalized = (codec or "mp3").lower().lstrip(".")
    known_types = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "wave": "audio/wav",
    }
    return known_types.get(normalized, "application/octet-stream")
async def _send_tts_audio(
    conversation_id: str,
    text: str,
    *,
    chunk_index: int,
    chunk_total: int,
    assistant_message_id: str | None,
    tts_epoch_start: int,
    manual: bool = False,
    language: str = "zh",
) -> str | None:
    """Synthesize ``text``, upload the audio to object storage and push a TTS_AUDIO message.

    Args:
        conversation_id: Conversation whose connection receives the audio.
        text: Text to synthesize.
        chunk_index: Index of this chunk, sent to the client as ``index``.
        chunk_total: Number of chunks, sent to the client as ``total``.
        assistant_message_id: Included in the payload when truthy.
        tts_epoch_start: Cancel epoch captured by the caller; a different current
            epoch (before or after synthesis) aborts the send.
        manual: When True the ``settings.enable_tts`` gate is bypassed and the
            payload carries ``manual=True``.
        language: Passed through to the TTS provider.

    Returns:
        The value returned by ``storage.upload`` (stored by callers as the audio
        URL), or ``None`` when TTS is disabled, cancelled, synthesis returned
        nothing, or any exception occurred (exceptions are logged, not raised).
    """
    current_epoch = _tts_epoch_value(conversation_id)
    # Keep permanently at INFO: the TTS decision and execution chain must stay
    # fully visible at INFO level.
    logger.info(
        "pipeline._send_tts_audio entry conversation_id={} chunk_index={} chunk_total={} "
        "text_len={} language={} manual={} tts_epoch_start={} current_epoch={} "
        "enable_tts={} provider={}",
        conversation_id,
        chunk_index,
        chunk_total,
        len(text or ""),
        language,
        manual,
        tts_epoch_start,
        current_epoch,
        settings.enable_tts,
        settings.tts_provider,
    )
    # enable_tts only disables automatically generated TTS for assistant replies
    # (the want_tts path); a user tapping the speaker (manual=True) can still synthesize.
    if not manual and not settings.enable_tts:
        logger.info(
            "pipeline._send_tts_audio result conversation_id={} chunk_index={} ok=False "
            "url_set=False audio_bytes_len=0 reason=enable_tts_false",
            conversation_id,
            chunk_index,
        )
        return None
    # Cancelled before synthesis started: skip the provider call entirely.
    if current_epoch != tts_epoch_start:
        logger.info(
            "pipeline._send_tts_audio result conversation_id={} chunk_index={} ok=False "
            "url_set=False audio_bytes_len=0 reason=epoch_mismatch_pre_synth "
            "tts_epoch_start={} current_epoch={}",
            conversation_id,
            chunk_index,
            tts_epoch_start,
            current_epoch,
        )
        return None
    try:
        tts = get_tts_provider()
        audio_bytes = await tts.synthesize(text, language=language)
        if not audio_bytes:
            logger.warning(
                "TTS skipped: synthesize returned empty conversation_id={} chunk_index={} "
                "language={} text_preview={!r} voice_provider={}",
                conversation_id,
                chunk_index,
                language,
                (text or "")[:30],
                settings.tts_provider,
            )
            logger.info(
                "pipeline._send_tts_audio result conversation_id={} chunk_index={} ok=False "
                "url_set=False audio_bytes_len=0 reason=synthesize_empty",
                conversation_id,
                chunk_index,
            )
            return None
        # Re-check after the await: the epoch may have been bumped while synthesizing.
        if _tts_epoch_value(conversation_id) != tts_epoch_start:
            logger.info(
                "pipeline._send_tts_audio result conversation_id={} chunk_index={} ok=False "
                "url_set=False audio_bytes_len={} reason=epoch_mismatch_post_synth",
                conversation_id,
                chunk_index,
                len(audio_bytes),
            )
            return None
        ext = _tts_object_ext(settings.tts_codec)
        content_type = _tts_codec_to_content_type(settings.tts_codec)
        storage = get_object_storage()
        # Random object name under the conversation's tts/ prefix.
        key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
        upload_started = time.perf_counter()
        logger.debug(
            "pipeline._send_tts_audio uploading key={} audio_bytes_len={} content_type={}",
            key,
            len(audio_bytes),
            content_type,
        )
        # NOTE(review): upload/get_url are called without await, i.e. synchronously
        # inside this coroutine — confirm they do not block the event loop.
        public_url = storage.upload(key, audio_bytes, content_type)
        upload_ms = (time.perf_counter() - upload_started) * 1000
        # Same approach as `tts_delivery.apply_presigned_tts_urls_to_messages` / memoir
        # image presigning: send the client a playable URL.
        playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
        logger.debug(
            "pipeline._send_tts_audio uploaded key={} audio_bytes_len={} upload_ms={:.2f} "
            "public_url_set={} playback_url_set={}",
            key,
            len(audio_bytes),
            upload_ms,
            bool(public_url),
            bool(playback_url),
        )
        # The payload carries both the inline base64 audio and the playback URL.
        audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
        payload_data: Dict[str, Any] = {
            "audio_base64": audio_b64,
            "format": settings.tts_codec,
            "audio_url": playback_url,
            "index": chunk_index,
            "total": chunk_total,
        }
        if assistant_message_id:
            payload_data["assistant_message_id"] = assistant_message_id
        if manual:
            payload_data["manual"] = True
        logger.debug(
            "pipeline._send_tts_audio sending TTS_AUDIO conversation_id={} chunk_index={} "
            "chunk_total={} payload_fields={} audio_b64_len={} manual={}",
            conversation_id,
            chunk_index,
            chunk_total,
            sorted(payload_data.keys()),
            len(audio_b64),
            manual,
        )
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": payload_data,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
        logger.info(
            "pipeline._send_tts_audio result conversation_id={} chunk_index={} ok=True "
            "url_set={} audio_bytes_len={} upload_ms={:.2f} manual={}",
            conversation_id,
            chunk_index,
            bool(public_url),
            len(audio_bytes),
            upload_ms,
            manual,
        )
        # Callers persist the upload result; the presigned playback URL is only sent to the client.
        return public_url
    except Exception as e:
        # Best effort: any failure in synthesis, upload or send is logged and reported as None.
        err_str = str(e)
        if "PkgExhausted" in err_str:
            logger.warning(
                "TTS skipped: 腾讯云语音合成资源包已用尽,请在控制台购买或开通后付费: {}",
                err_str[:100],
            )
        else:
            logger.error("TTS synthesize failed: {}", e)
        logger.info(
            "pipeline._send_tts_audio result conversation_id={} chunk_index={} ok=False "
            "url_set=False audio_bytes_len=0 reason=exception err={}",
            conversation_id,
            chunk_index,
            type(e).__name__,
        )
        return None
async def handle_tts_request_on_demand(
    *,
    conversation_id: str,
    user_id: str,
    assistant_message_id: str,
    segment_index: int,
    segment_text: str | None,
    db: AsyncSession,
) -> tuple[bool, str]:
    """Handle a user tapping the speaker icon on one segment of an assistant message.

    If the segment already has a stored TTS URL, it is presigned and re-sent;
    otherwise the segment is synthesized, the URL is persisted on the message and
    the audio is sent. A segment that already has a URL is not synthesized again.

    Args:
        conversation_id: Conversation the message belongs to.
        user_id: Requesting user; must own the conversation.
        assistant_message_id: Id of the AI message to read aloud.
        segment_index: Index into the message's segments.
        segment_text: Client-side text of the segment; only compared for logging,
            the server-side segment text is what gets synthesized.
        db: Async session used for lookups and for persisting the new URL.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, <user-facing reason>)``.
    """
    logger.info(
        "pipeline.handle_tts_request_on_demand entry conversation_id={} user_id={} "
        "assistant_message_id={} segment_index={} segment_text_len={} enable_tts={} provider={}",
        conversation_id,
        user_id,
        assistant_message_id,
        segment_index,
        len(segment_text or ""),
        settings.enable_tts,
        settings.tts_provider,
    )
    # Ownership / soft-delete check on the conversation.
    conv = await db.get(Conversation, conversation_id)
    if not conv or conv.user_id != user_id or conv.deleted_at is not None:
        logger.debug(
            "pipeline.handle_tts_request_on_demand result ok=False reason=对话不存在或无权访问 "
            "conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )
        return False, "对话不存在或无权访问"
    # Only AI messages of this conversation can be read aloud.
    msg = await db.get(ConversationMessage, assistant_message_id)
    if not msg or msg.conversation_id != conversation_id or msg.role != "ai":
        logger.debug(
            "pipeline.handle_tts_request_on_demand result ok=False reason=消息不存在 "
            "conversation_id={} assistant_message_id={}",
            conversation_id,
            assistant_message_id,
        )
        return False, "消息不存在"
    # Aligned with the client's splitMessageParts / segments_from_llm_response
    # (including paragraph-based splitting when there is no [SPLIT] marker).
    parts = segments_from_llm_response(msg.content or "", max_segments=3)
    if segment_index < 0 or segment_index >= len(parts):
        return False, "分段序号无效"
    canon = (parts[segment_index] or "").strip()
    if not canon:
        return False, "该段无朗读文本"
    if segment_text and segment_text.strip() and segment_text.strip() != canon:
        # A mismatch is only logged; the server-side text (canon) is authoritative.
        logger.debug(
            "按需 TTS: 客户端传入 segment_text 与规范化后 canon 不完全一致,已按 segment_index 朗读 canon "
            "(client_len={} canon_len={})",
            len(segment_text.strip()),
            len(canon),
        )
    # Normalize the stored URLs into a list indexed by segment, using "" for
    # missing/invalid entries and padding up to the number of segments.
    urls: List[str] = []
    for x in msg.tts_audio_urls or []:
        if isinstance(x, str) and x.strip():
            urls.append(x)
        else:
            urls.append("")
    while len(urls) < len(parts):
        urls.append("")
    existing = urls[segment_index].strip() if segment_index < len(urls) else ""
    chunk_total = len(parts)
    if existing:
        # Reuse path: no synthesis, just hand out a playable URL for the stored audio.
        logger.info(
            "pipeline.handle_tts_request_on_demand reuse existing url conversation_id={} "
            "assistant_message_id={} segment_index={} url_len={}",
            conversation_id,
            assistant_message_id,
            segment_index,
            len(existing),
        )
        storage = get_object_storage()
        key = extract_cos_object_key_if_owned(existing)
        try:
            # Presign only when a key could be extracted; otherwise send the stored URL as-is.
            playback_url = (
                storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
                if key
                else existing
            )
        except Exception as exc:
            logger.warning("按需 TTS 预签名失败: {}", exc)
            playback_url = existing
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": {
                    "audio_url": playback_url,
                    "format": settings.tts_codec,
                    "index": segment_index,
                    "total": chunk_total,
                    "assistant_message_id": assistant_message_id,
                    "manual": True,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
        logger.info(
            "pipeline.handle_tts_request_on_demand result ok=True reason=existing_reused "
            "conversation_id={} assistant_message_id={} segment_index={}",
            conversation_id,
            assistant_message_id,
            segment_index,
        )
        return True, ""
    logger.info(
        "pipeline.handle_tts_request_on_demand no existing url, will synthesize "
        "conversation_id={} assistant_message_id={} segment_index={} canon_len={}",
        conversation_id,
        assistant_message_id,
        segment_index,
        len(canon),
    )
    user_obj = await db.get(User, user_id)
    user_language = _resolve_user_language(user_obj)
    # Capture the current epoch so a tts_cancel arriving during synthesis aborts the send.
    tts_epoch_start = _tts_epoch_value(conversation_id)
    url_stored = await _send_tts_audio(
        conversation_id,
        canon,
        chunk_index=segment_index,
        chunk_total=chunk_total,
        assistant_message_id=assistant_message_id,
        tts_epoch_start=tts_epoch_start,
        manual=True,
        language=user_language,
    )
    logger.info(
        "pipeline.handle_tts_request_on_demand _send_tts_audio returned url_stored_set={} "
        "conversation_id={} assistant_message_id={} segment_index={}",
        bool(url_stored),
        conversation_id,
        assistant_message_id,
        segment_index,
    )
    if not url_stored:
        logger.info(
            "pipeline.handle_tts_request_on_demand result ok=False reason=语音合成失败 "
            "conversation_id={} assistant_message_id={} segment_index={}",
            conversation_id,
            assistant_message_id,
            segment_index,
        )
        return False, "语音合成失败"
    # Persist the new URL at its segment position so the next request reuses it.
    while len(urls) <= segment_index:
        urls.append("")
    urls[segment_index] = url_stored
    msg.tts_audio_urls = urls
    await db.commit()
    store = ConversationHistoryStore(db)
    await store._sync_redis_best_effort(conversation_id)
    logger.info(
        "pipeline.handle_tts_request_on_demand result ok=True reason=synthesized "
        "conversation_id={} assistant_message_id={} segment_index={}",
        conversation_id,
        assistant_message_id,
        segment_index,
    )
    return True, ""
# ── Agent instances (moved out of ConnectionManager) ─────────────────────
# Module-level singletons created at import time.
chat_orchestrator = ChatOrchestrator()
chat_turn_service = ChatTurnService(chat_orchestrator)
_background_runner = BackgroundTaskRunner()
memoir_ingest_scheduler = MemoirIngestScheduler(_background_runner)
# ── Segment stream state ──────────────────────────────────────────────────
@dataclass
class SegmentStreamState:
    """Per-voice-session segment processing state (parallel ASR + ordered aggregation)."""

    # Guards every mutable field below.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    #: Read-aloud switch for this turn, as carried by the most recent segment upload
    #: of this voice session (the client only needs to send the same value on every segment).
    tts_this_turn: bool = False
    # Segment indices still in flight; process_audio_segment discards its index on exit.
    pending_indices: Set[int] = field(default_factory=set)
    # Segment indices that were persisted or found already persisted.
    processed_indices: Set[int] = field(default_factory=set)
    # segment_index -> (transcript, Segment) waiting for its turn in index order.
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest index already handed over for reply generation; -1 means none yet.
    consumed_index: int = -1
    # Tasks registered through register_segment_task that have not finished.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    # Whether the one-off listening feedback has been sent, and the task scheduled to send it.
    listening_feedback_sent: bool = False
    listening_feedback_task: Optional[asyncio.Task] = None
# Segment stream state keyed by (conversation_id, voice_session_id).
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}
# Background user-response tasks, keyed by conversation_id.
_user_response_tasks: Dict[str, Set[asyncio.Task]] = {}
# Per-conversation lock taken by process_persisted_user_segment_response.
_user_response_locks: Dict[str, asyncio.Lock] = {}
def _get_user_response_lock(conversation_id: str) -> asyncio.Lock:
    """Return the user-response lock of ``conversation_id``, creating it on first use."""
    try:
        return _user_response_locks[conversation_id]
    except KeyError:
        created = asyncio.Lock()
        _user_response_locks[conversation_id] = created
        return created
def register_user_response_task(conversation_id: str, task: asyncio.Task) -> None:
    """Track a background user-response task until it finishes.

    When the task completes it is removed from the per-conversation set; once the
    set is empty, the conversation's task entry and response lock are dropped.
    Cancellation is logged as a warning, a raised exception as an error.
    """
    pending = _user_response_tasks.setdefault(conversation_id, set())
    pending.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        pending.discard(done_task)
        if not pending:
            _user_response_tasks.pop(conversation_id, None)
            _user_response_locks.pop(conversation_id, None)
        if done_task.cancelled():
            # exception() would raise CancelledError on a cancelled task, so stop here.
            logger.warning(
                "用户回复后台任务被取消 conversation_id={}",
                conversation_id,
            )
        elif failure := done_task.exception():
            logger.error(
                "用户回复后台任务异常 conversation_id={}: {}",
                conversation_id,
                failure,
                exc_info=True,
            )

    task.add_done_callback(_cleanup)
def get_or_create_segment_state(
    conversation_id: str,
    voice_session_id: str,
) -> SegmentStreamState:
    """Return the stream state for this conversation/voice session, creating it if absent."""
    state_key = (conversation_id, voice_session_id)
    try:
        return _segment_states[state_key]
    except KeyError:
        fresh_state = SegmentStreamState()
        _segment_states[state_key] = fresh_state
        return fresh_state
def register_segment_task(
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track a segment-processing task on its voice session's stream state.

    When the task finishes it is removed from ``active_tasks``; if no task is left
    and the conversation has no active connection, the stream state is dropped.
    A task that ended with an exception is logged as an error.
    """
    state_key = (conversation_id, voice_session_id)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        if not state.active_tasks and conversation_id not in manager.active_connections:
            _segment_states.pop(state_key, None)
        if done_task.cancelled():
            # exception() would raise CancelledError on a cancelled task.
            return
        exc = done_task.exception()
        if exc:
            # Use placeholders like the rest of this module: interpolating the
            # exception into the message itself would let braces in its text be
            # re-interpreted as format fields once extra arguments are passed.
            logger.error(
                "分段处理任务异常 (conversation_id={}, voice_session_id={}): {}",
                conversation_id,
                voice_session_id,
                exc,
                exc_info=True,
            )

    task.add_done_callback(_cleanup)
def cleanup_segment_states(conversation_id: str) -> None:
    """Drop the conversation's segment states that have no active tasks (used after disconnect)."""
    # Iterate over a snapshot because entries are removed while scanning.
    for state_key, state in list(_segment_states.items()):
        owner_conversation_id = state_key[0]
        if owner_conversation_id == conversation_id and not state.active_tasks:
            _segment_states.pop(state_key, None)
# ── Utility functions ────────────────────────────────────────────────────
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
def _mark_conversation_active(
    conversation: Conversation, at: Optional[datetime] = None
) -> datetime:
    """Set ``conversation.last_message_at`` to ``at`` (or the current UTC time) and return it."""
    if not at:
        at = _utc_now()
    conversation.last_message_at = at
    return at
def _voice_session_id_from_client_segment_id(
    client_segment_id: Optional[str],
) -> Optional[str]:
    """Return the part of ``client_segment_id`` before its last ``-``.

    Returns ``None`` when the id is empty, has no ``-``, or has nothing before it.
    """
    head, dash, _suffix = (client_segment_id or "").rpartition("-")
    return head if (dash and head) else None
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
    """Build the idempotency marker stored as a voice segment's ``audio_url``.

    The marker has the form ``audio-segment:{voice_session_id}:{segment_index}``.
    """
    marker = f"audio-segment:{voice_session_id}:{segment_index}"
    return marker
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse ``audio-segment:{session_id}:{index}`` into ``(session_id, index)``.

    The index is whatever follows the last ``:``. Returns ``None`` when the value
    lacks the prefix, the separator, a non-blank session id, or an integer index.
    """
    marker = "audio-segment:"
    if not audio_url or not audio_url.startswith(marker):
        return None
    session_part, colon, index_part = audio_url[len(marker) :].rpartition(":")
    session_id = str(session_part).strip()
    if not colon or not session_id:
        return None
    try:
        index = int(index_part)
    except ValueError:
        return None
    return (session_id, index)
def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in a segment ``audio_url`` marker, or ``None``."""
    scope = _extract_segment_scope(audio_url)
    return scope[0] if scope else None
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
    """Return True for an empty transcript or one starting with the failure marker ``转写失败``."""
    return (not transcript_text) or transcript_text.startswith("转写失败")
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Return the newest persisted segment for this voice session index, or ``None``.

    Segments are matched on ``conversation_id`` plus the marker produced by
    ``_build_segment_audio_url``, ordered by ``created_at`` descending.
    """
    expected_audio_url = _build_segment_audio_url(voice_session_id, segment_index)
    query = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.audio_url == expected_audio_url,
    )
    query = query.order_by(Segment.created_at.desc())
    rows = (await db.execute(query)).scalars().all()
    # The rows are re-checked in Python, exactly as the query filters them.
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id
            and row.audio_url == expected_audio_url
        ),
        None,
    )
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest index N such that segments 0..N of this voice session are persisted.

    Used to restore state after a reconnect. Returns -1 when segment 0 is not persisted.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()
    seen_indices: Set[int] = set()
    for row in rows:
        if row.conversation_id != conversation_id:
            continue
        scope = _extract_segment_scope(row.audio_url)
        # Keep only rows whose marker belongs to this voice session.
        if scope and scope[0] == voice_session_id:
            seen_indices.add(scope[1])
    last_contiguous = -1
    while (last_contiguous + 1) in seen_indices:
        last_contiguous += 1
    return last_contiguous
# ── Transition feedback ────────────────────────────────────────────────────
# Seconds _delayed_listening_feedback waits before sending the feedback.
LISTENING_FEEDBACK_DELAY_SEC = 5.0
# User-facing text of the feedback message (sent to the client as-is).
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
) -> None:
    """Send one "I'm listening" transition message (invoked by the delayed feedback task)."""
    feedback_message = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": {
            "text": LISTENING_FEEDBACK_TEXT,
            # Marks the message as transition feedback for the client.
            "transition": True,
            "segment_index": segment_index,
        },
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, feedback_message)
async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """Wait ``LISTENING_FEEDBACK_DELAY_SEC`` and send the listening feedback once per voice session.

    Nothing is sent when the session state already records the feedback as sent.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    session_state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with session_state.lock:
        already_sent = session_state.listening_feedback_sent
        if not already_sent:
            session_state.listening_feedback_sent = True
            session_state.listening_feedback_task = None
    if already_sent:
        return
    # The message is sent after the lock has been released.
    await _send_segment_transition_feedback(conversation_id, 0)
# ── Long-audio chunked transcription ────────────────────────────────────────
# Maximum length, in milliseconds, of one audio chunk handed to ASR (55 s).
MAX_ASR_CHUNK_MS = 55_000
def _split_audio_bytes(audio_bytes: bytes, fmt: str) -> list[bytes]:
    """Split long audio into chunks of at most ``MAX_ASR_CHUNK_MS`` using pydub.

    Audio that fits in one chunk is returned unchanged as a single-element list.
    Longer audio is converted to 16 kHz mono 16-bit and each chunk is exported as
    WAV (per the original author's note, to stay within Tencent ASR's 3 MB limit).
    """
    from pydub import AudioSegment as PydubSegment

    decoded = PydubSegment.from_file(io.BytesIO(audio_bytes), format=fmt)
    total_ms = len(decoded)
    if total_ms <= MAX_ASR_CHUNK_MS:
        return [audio_bytes]

    normalized = decoded.set_frame_rate(16000).set_channels(1).set_sample_width(2)

    def _to_wav(piece) -> bytes:
        sink = io.BytesIO()
        piece.export(sink, format="wav")
        return sink.getvalue()

    return [
        _to_wav(normalized[offset : offset + MAX_ASR_CHUNK_MS])
        for offset in range(0, total_ms, MAX_ASR_CHUNK_MS)
    ]
async def _transcribe_long_audio(audio_bytes: bytes, fmt: str = "m4a") -> str:
    """Transcribe audio, splitting it into chunks transcribed concurrently when it is long.

    If splitting fails or yields a single chunk, the original bytes are transcribed
    directly. Otherwise each chunk is transcribed as WAV; chunks that raise or
    return a failure/empty transcript are skipped and the rest are concatenated.
    """
    provider = get_asr_provider()
    try:
        pieces = await asyncio.to_thread(_split_audio_bytes, audio_bytes, fmt)
    except Exception as exc:
        logger.warning("pydub 切片失败 ({}), 回退到直接转写", exc)
        pieces = None
    if pieces is None or len(pieces) <= 1:
        return await provider.transcribe(audio_bytes, format=fmt)

    logger.info("长音频切片: {}", len(pieces))
    outcomes = await asyncio.gather(
        *(provider.transcribe(piece, format="wav") for piece in pieces),
        return_exceptions=True,
    )
    recognized: list[str] = []
    for position, outcome in enumerate(outcomes):
        if isinstance(outcome, BaseException):
            logger.warning("切片 {} 转写异常: {}", position, outcome)
        elif outcome and not _is_transcribe_failure(outcome):
            recognized.append(outcome)
    return "".join(recognized)
# ── Asynchronous voice segment processing ────────────────────────────────────
async def process_audio_segment(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
    *,
    tts_this_turn: bool = False,
) -> None:
    """Process one uploaded voice segment: ASR, idempotent persistence, ordered hand-over.

    The segment is transcribed, the transcript is sent to the client, a ``Segment``
    row is stored (skipped when one already exists for this index), and every
    buffered segment that is next in index order is passed to
    ``process_user_message``.

    Args:
        conversation_id: Conversation receiving the segment.
        user_id: Owner of the conversation.
        voice_session_id: Identifies the recording the segment belongs to.
        segment_index: Position of the segment within the voice session.
        audio_base64: Base64-encoded audio (transcribed as m4a).
        audio_duration: Duration reported by the client; stored when > 0.
        is_last: Echoed back to the client in the TRANSCRIPT message.
        tts_this_turn: Read-aloud switch for this turn, recorded on the session state.

    Errors are reported to the client as ERROR messages and logged; nothing is raised
    for failures inside the processing body.
    """
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        state.tts_this_turn = bool(tts_this_turn)
    logger.info(
        "process_audio_segment 开始: conversation_id={} voice_session_id={} "
        "segment_index={} is_last={} duration_s={} audio_b64_len={}",
        conversation_id,
        voice_session_id,
        segment_index,
        is_last,
        audio_duration,
        len(audio_base64 or ""),
    )
    try:
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            if not conversation or conversation.deleted_at is not None:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "对话不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            if not user:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "用户不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            # A pristine state (e.g. after a reconnect) is primed from the database
            # so already persisted segments are not waited for again.
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )
            if should_prime_state:
                persisted_contiguous_index = (
                    await _get_persisted_contiguous_segment_index(
                        db=db,
                        conversation_id=conversation_id,
                        voice_session_id=voice_session_id,
                    )
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        state.consumed_index = max(
                            state.consumed_index, persisted_contiguous_index
                        )
            try:
                audio_bytes = base64.b64decode(audio_base64)
            except Exception:
                audio_bytes = b""
            if not audio_bytes:
                logger.warning(
                    "process_audio_segment: 解码后音频为空 conversation_id={} segment_index={}",
                    conversation_id,
                    segment_index,
                )
            try:
                transcript_text = await _transcribe_long_audio(audio_bytes, fmt="m4a")
            except ASRTranscriptionError as e:
                logger.warning(
                    "ASR 转写失败 segment_index={} conversation_id={}: {}",
                    segment_index,
                    conversation_id,
                    e,
                )
                transcript_text = ""
            # The transcript (possibly empty) is always reported to the client first.
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.TRANSCRIPT,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": transcript_text or "",
                        "audio_duration": audio_duration,
                        "voice_session_id": voice_session_id,
                        "segment_index": segment_index,
                        "is_last": is_last,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            if _is_transcribe_failure(transcript_text):
                detail = (transcript_text or "").strip()
                if not detail:
                    user_msg = f"分段 {segment_index} 未识别到语音内容,请重试或检查麦克风与网络"
                else:
                    user_msg = f"分段 {segment_index} 语音识别失败,请稍后再试"
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            "message": user_msg,
                            "segment_index": segment_index,
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                # Idempotency: this index was already persisted, so do not store or answer it again.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.debug(
                    "分段已存在,按幂等跳过: conversation_id={} voice_session_id={} "
                    "segment_index={} segment_id={} transcript={}",
                    conversation_id,
                    voice_session_id,
                    segment_index,
                    existing_segment.id,
                    existing_segment.user_input_text or "",
                )
                return
            else:
                segment = Segment(
                    id=str(uuid.uuid4()),
                    conversation_id=conversation_id,
                    user_input_text=transcript_text or "",
                    audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                    audio_duration_seconds=audio_duration
                    if audio_duration > 0
                    else None,
                    processed=False,
                )
                db.add(segment)
                user_message_timestamp = _mark_conversation_active(conversation)
                await db.commit()
                await db.refresh(segment)
                await memoir_ingest_scheduler.queue_segment(
                    conversation.user_id,
                    segment.id,
                    text_char_count=len((transcript_text or "").strip()),
                )
            ready_segments: List[Tuple[int, str, Segment]] = []
            tts_flag_this_voice_session = False
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (
                    transcript_text or "",
                    segment,
                )
                # Drain every buffered segment that directly follows the last consumed index.
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1
                tts_flag_this_voice_session = bool(state.tts_this_turn)
            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at
                    or user_message_timestamp,
                    tts_this_turn=tts_flag_this_voice_session,
                )
    except Exception as e:
        # Placeholders instead of an f-string: exception text often contains braces
        # (e.g. parameter dicts), which must not be re-interpreted as format fields
        # when extra arguments are passed to the logger. A formatting error here
        # would also prevent the ERROR message below from reaching the client.
        logger.error(
            "处理语音分段失败: conversation_id={}, segment_index={}, error={}",
            conversation_id,
            segment_index,
            e,
            exc_info=True,
        )
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.ERROR,
                "data": {
                    "message": "语音分段处理遇到问题,请重试",
                    "segment_index": segment_index,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
    finally:
        async with state.lock:
            state.pending_indices.discard(segment_index)
# ── User message processing ──────────────────────────────────────────────────
async def process_persisted_user_segment_response(
    *,
    conversation_id: str,
    user_id: str,
    segment_id: str,
    tts_this_turn: bool = False,
) -> None:
    """Generate the assistant reply for an already persisted user segment in the background.

    Runs under the conversation's user-response lock with its own DB session, so the
    reply is still persisted even if the WebSocket page has been left. The task is
    skipped (with a warning) when the conversation, user or segment is missing, the
    conversation is deleted or owned by someone else, or the segment belongs to
    another conversation.
    """
    async with _get_user_response_lock(conversation_id):
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            segment = await db.get(Segment, segment_id)

            conversation_ok = (
                bool(conversation)
                and conversation.deleted_at is None
                and conversation.user_id == user_id
            )
            segment_ok = bool(segment) and segment.conversation_id == conversation_id
            if not (conversation_ok and user and segment_ok):
                logger.warning(
                    "跳过用户回复后台任务: conversation_id={} segment_id={} user_id={}",
                    conversation_id,
                    segment_id,
                    user_id,
                )
                return

            fallback_timestamp = segment.created_at or conversation.last_message_at
            await process_user_message(
                conversation_id=conversation_id,
                user_message=segment.user_input_text or "",
                conversation=conversation,
                segment=segment,
                db=db,
                user=user,
                user_message_timestamp=fallback_timestamp,
                tts_this_turn=tts_this_turn,
            )
async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: Optional[User] = None,
    user_message_timestamp: Optional[datetime] = None,
    *,
    force_skip_tts: bool = False,
    tts_this_turn: Optional[bool] = None,
) -> None:
    """Handle one user message: run the chat turn, persist it, and stream the reply.

    The turn is produced by ``chat_turn_service`` (ChatOrchestrator routes to the
    ProfileAgent or InterviewAgent), recorded through the history store, and each
    reply segment is sent as AGENT_RESPONSE. When TTS is wanted for this turn, the
    audio of a segment is synthesized and sent before its text.

    Args:
        conversation_id: Conversation being answered.
        user_message: The user's text (typed or transcribed).
        conversation: Loaded conversation row; its activity time is updated.
        segment: Segment row holding the user input; receives the reply and lineage.
        db: Async session used for all persistence.
        user: Owner of the conversation; used for profile handling and TTS language.
        user_message_timestamp: Timestamp recorded for the user message.
        force_skip_tts: Forwarded to the chat turn input.
        tts_this_turn: Whether the client asked for read-aloud on this turn;
            ``None`` is treated as False.

    Exceptions are logged and reported to the client as an ERROR message; TTS URLs
    gathered before the failure are still persisted on a best-effort basis.
    """
    store = ConversationHistoryStore(db)
    # Index-aligned with the reply segments ("" marks a segment without audio).
    tts_urls: list[str] = []
    user_language = _resolve_user_language(user)
    try:
        logger.info(
            "process_user_message 开始: conversation_id={} segment_id={} user_chars={}",
            conversation_id,
            segment.id,
            len(user_message or ""),
        )
        # Keep permanently: TTS decision entry point (pipeline layer); all control
        # flags stay visible at INFO level.
        logger.info(
            "pipeline.process_user_message entry conversation_id={} segment_id={} "
            "tts_this_turn={} force_skip_tts={} enable_tts={} provider={} user_language={}",
            conversation_id,
            segment.id,
            tts_this_turn,
            force_skip_tts,
            settings.enable_tts,
            settings.tts_provider,
            user_language,
        )
        is_from_voice = bool(segment.audio_url)
        voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
        audio_dur = getattr(segment, "audio_duration_seconds", None)
        t_pipeline = time.perf_counter()
        turn = await chat_turn_service.process_turn(
            ChatTurnInput(
                conversation_id=conversation_id,
                user_message=user_message,
                is_from_voice=is_from_voice,
                voice_session_id=voice_session_id,
                user_message_timestamp=user_message_timestamp,
                audio_duration_seconds=audio_dur,
                force_skip_tts=force_skip_tts,
            ),
            ChatTurnContext(
                db=db,
                user=user,
                conversation=conversation,
                apply_extracted_profile_fn=apply_extracted_profile,
                get_missing_profile_fields_fn=get_missing_profile_fields,
                get_filled_profile_fields_fn=get_filled_profile_fields,
            ),
        )
        responses = turn.messages
        skip_tts = bool(turn.skip_tts)
        # Automatic TTS needs all three: the client asked for it, the feature is
        # enabled, and the turn did not opt out.
        want_voice = bool(tts_this_turn) if tts_this_turn is not None else False
        want_tts = want_voice and settings.enable_tts and not skip_tts
        # Keep permanently at INFO: final TTS decision; no longer gated by agent_summary_enabled.
        logger.info(
            "pipeline.process_user_message tts_decision conversation_id={} segment_id={} "
            "tts_this_turn={} force_skip_tts={} enable_tts={} skip_tts_from_turn={} "
            "want_voice={} want_tts={} response_segments={}",
            conversation_id,
            segment.id,
            tts_this_turn,
            force_skip_tts,
            settings.enable_tts,
            skip_tts,
            want_voice,
            want_tts,
            len(turn.messages),
        )
        if agent_summary_enabled():
            logger.info(
                "pipeline.process_user_message duration_ms={:.2f} "
                "conversation_id={} segment_id={} user_msg_len={} "
                "response_segments={} skip_tts={} want_tts={}",
                (time.perf_counter() - t_pipeline) * 1000,
                conversation_id,
                segment.id,
                len(user_message or ""),
                len(turn.messages),
                turn.skip_tts,
                want_tts,
            )
        segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses)
        _mark_conversation_active(conversation)
        turn_ids = await store.record_human_ai_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            responses=responses,
            user_message_timestamp=user_message_timestamp,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            audio_duration_seconds=audio_dur,
            tts_audio_urls=None,
            segment_id=segment.id,
            memory_retrieval_trace=turn.memory_retrieval_trace,
        )
        if not turn_ids:
            logger.warning(
                "process_user_message: 无有效助手段落responses 为空conversation_id={} segment_id={}",
                conversation_id,
                segment.id,
            )
            if conversation_id in manager.active_connections:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            "message": "未生成回复,请重试或稍后再试",
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            return
        lineage = DialogueLineage.for_single_turn(
            conversation_id=conversation_id,
            user_message_id=turn_ids.human_message_id,
            assistant_message_id=turn_ids.assistant_message_id,
            segment_ids=[str(segment.id)],
        )
        await db.execute(
            update(Segment)
            .where(Segment.id == segment.id)
            .values(
                user_message_id=turn_ids.human_message_id,
                lineage_json=lineage.model_dump(mode="json"),
            )
        )
        await db.commit()
        ai_msg_id = turn_ids.assistant_message_id
        tts_epoch_start = _tts_epoch_value(conversation_id)
        n = len(responses)
        # tts_cancelled only skips the remaining TTS synthesis. AGENT_RESPONSE must
        # still be sent for every segment, otherwise the front end stays on
        # "replying…" or loses the trailing text.
        tts_cancelled = False
        for i, response_text in enumerate(responses):
            url_for_segment: Optional[str] = None
            if want_tts and not tts_cancelled:
                if _tts_epoch_value(conversation_id) != tts_epoch_start:
                    tts_cancelled = True
                    logger.info(
                        "pipeline.process_user_message segment={}/{} tts_branch=skip_cancelled "
                        "tts_cancelled={} conversation_id={}",
                        i,
                        n,
                        tts_cancelled,
                        conversation_id,
                    )
                else:
                    logger.info(
                        "pipeline.process_user_message segment={}/{} tts_branch=synthesize "
                        "tts_cancelled={} conversation_id={}",
                        i,
                        n,
                        tts_cancelled,
                        conversation_id,
                    )
                    url_for_segment = await _send_tts_audio(
                        conversation_id,
                        response_text,
                        chunk_index=i,
                        chunk_total=n,
                        assistant_message_id=ai_msg_id,
                        tts_epoch_start=tts_epoch_start,
                        language=user_language,
                    )
                    if url_for_segment:
                        # Stored URLs are looked up by segment index
                        # (handle_tts_request_on_demand reads tts_audio_urls[segment_index]).
                        # Pad with "" for earlier segments whose synthesis failed so this
                        # URL is not attributed to the wrong segment.
                        tts_urls.extend([""] * (i - len(tts_urls)))
                        tts_urls.append(url_for_segment)
                    if _tts_epoch_value(conversation_id) != tts_epoch_start:
                        tts_cancelled = True
            else:
                logger.info(
                    "pipeline.process_user_message segment={}/{} tts_branch={} "
                    "tts_cancelled={} want_tts={} conversation_id={}",
                    i,
                    n,
                    "skip_cancelled" if tts_cancelled else "skip_no_tts",
                    tts_cancelled,
                    want_tts,
                    conversation_id,
                )
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": response_text,
                        "index": i,
                        "total": n,
                        "assistant_message_id": ai_msg_id,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            # Short pause between segments (not after the last one).
            if i < n - 1:
                await asyncio.sleep(0.5)
        if tts_urls:
            await store.attach_ai_tts_audio_urls(
                conversation_id,
                tts_audio_urls=tts_urls,
                segment_id=segment.id,
            )
            await db.execute(
                update(Segment)
                .where(Segment.id == segment.id)
                .values(tts_audio_urls=tts_urls)
            )
            await db.commit()
    except Exception as e:
        # Best effort: keep the URLs of audio that was already uploaded and sent.
        if tts_urls:
            try:
                await store.attach_ai_tts_audio_urls(
                    conversation_id,
                    tts_audio_urls=tts_urls,
                    segment_id=segment.id,
                )
                await db.execute(
                    update(Segment)
                    .where(Segment.id == segment.id)
                    .values(tts_audio_urls=tts_urls)
                )
                await db.commit()
            except Exception as persist_error:
                logger.warning("补写 TTS 元数据失败: {}", persist_error)
        logger.exception("处理用户消息失败: {}", e)
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "生成回应时遇到问题,请稍后再试"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            except Exception as send_error:
                logger.warning("发送错误消息失败: {}", send_error)
# ── Conversation end handling ────────────────────────────────────────────────
async def process_conversation_segments(
    conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
) -> None:
    """
    At conversation end, hand the segments still awaiting Phase1 to the memoir pipeline.

    `MemoirIngestScheduler.flush_pending` merges the in-memory debounce batch with the
    `topic_category IS NULL` segment ids queried here, de-duplicates them and submits
    `process_memoir_phase1` **once**, then triggers the pending-narrative Phase2 dispatch
    at the end of the flush. This avoids Phase1 being submitted twice by the
    conversation-end path and the debounce flush.

    The quota check goes through the injected `quota_service`; quota internals are
    not imported directly. When there are no matching segments, or the quota is
    exhausted, the pending batch is still flushed but without the extra segment ids.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation or conversation.deleted_at is not None:
        return
    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False,  # noqa: E712 - SQL expression, not a Python comparison
        Segment.topic_category.is_(None),
    )
    result = await db.execute(stmt)
    segments = result.scalars().all()
    if not segments:
        await memoir_ingest_scheduler.flush_pending(
            conversation.user_id,
            trigger="conversation_end",
        )
        return
    user = await db.get(User, conversation.user_id)
    if user:
        can_submit, _ = await quota_service.check_can_submit_organize(
            user.id, user.subscription_type
        )
        if not can_submit:
            # Placeholder formatting, consistent with the other log calls in this module.
            logger.info(
                "用户 {} 章节配额已用尽,跳过提交整理任务: conversation_id={}",
                user.id,
                conversation_id,
            )
            await memoir_ingest_scheduler.flush_pending(
                conversation.user_id,
                trigger="conversation_end",
            )
            return
    segment_ids = [seg.id for seg in segments]
    try:
        await memoir_ingest_scheduler.flush_pending(
            conversation.user_id,
            extra_segment_ids=segment_ids,
            trigger="conversation_end",
        )
        logger.info(
            "对话结束,合并批内 segment 与 DB 待分类段,单次提交 Phase1: "
            "conversation_id={} segments={}",
            conversation_id,
            len(segment_ids),
        )
    except Exception as e:
        # Deliberately best effort: a failed submission must not break the
        # conversation-end path, but keep the traceback for diagnosis.
        logger.exception("提交 Celery 任务失败: {}", e)