2026-03-18 17:18:23 +08:00
|
|
|
|
"""核心消息处理管道:Agent 调用、ASR 转写、分段有序聚合"""
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
import asyncio
|
|
|
|
|
|
import base64
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
|
|
|
|
|
|
2026-03-20 15:15:35 +08:00
|
|
|
|
from app.core.logging import get_logger
|
|
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
|
from app.features.quota.service import QuotaService
|
|
|
|
|
|
|
2026-03-20 16:36:42 +08:00
|
|
|
|
from sqlalchemy import select, update
|
2026-03-18 17:18:23 +08:00
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
|
|
from app.agents import ConversationAgent, MemoryAgent
|
2026-03-19 10:36:55 +08:00
|
|
|
|
from app.agents.chat import ChatOrchestrator
|
|
|
|
|
|
from app.agents.memoir import BackgroundTaskRunner
|
2026-03-20 15:15:35 +08:00
|
|
|
|
from app.core.config import settings
|
2026-03-18 17:18:23 +08:00
|
|
|
|
from app.core.db import AsyncSessionLocal
|
2026-03-20 16:36:42 +08:00
|
|
|
|
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
|
|
|
|
|
|
from app.core.redis import redis_service
|
2026-03-18 17:18:23 +08:00
|
|
|
|
from app.features.conversation.models import Conversation, Segment
|
|
|
|
|
|
from app.features.conversation.ws.connection_manager import manager
|
2026-03-19 14:36:14 +08:00
|
|
|
|
from app.features.conversation.ws.message_types import (
|
|
|
|
|
|
LEGACY_VOICE_SESSION_ID,
|
|
|
|
|
|
MessageType,
|
|
|
|
|
|
)
|
2026-03-18 17:18:23 +08:00
|
|
|
|
from app.features.conversation.ws.profile_collector import (
|
|
|
|
|
|
apply_extracted_profile,
|
|
|
|
|
|
get_filled_profile_fields,
|
|
|
|
|
|
get_missing_profile_fields,
|
|
|
|
|
|
)
|
|
|
|
|
|
from app.features.user.models import User
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
2026-03-19 09:58:02 +08:00
|
|
|
|
|
2026-03-20 16:36:42 +08:00
|
|
|
|
def _tts_object_ext(codec: str) -> str:
|
|
|
|
|
|
c = (codec or "mp3").lower().lstrip(".")
|
|
|
|
|
|
if c in ("wave",):
|
|
|
|
|
|
return "wav"
|
|
|
|
|
|
return c if c else "mp3"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tts_codec_to_content_type(codec: str) -> str:
|
|
|
|
|
|
c = (codec or "mp3").lower().lstrip(".")
|
|
|
|
|
|
if c == "mp3":
|
|
|
|
|
|
return "audio/mpeg"
|
|
|
|
|
|
if c in ("wav", "wave"):
|
|
|
|
|
|
return "audio/wav"
|
|
|
|
|
|
return "application/octet-stream"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _send_tts_audio(
    conversation_id: str,
    text: str,
    *,
    chunk_index: int,
    chunk_total: int,
) -> str | None:
    """Synthesize speech for one reply chunk, store it, and push it to the client.

    Steps, in order: synthesize ``text`` with the TTS provider, upload the
    audio to object storage, record the public URL through
    ``redis_service.append_tts_audio_url_to_last_ai_message``, then send a
    ``TTS_AUDIO`` message carrying both the base64 audio and the URL.

    The whole operation is best-effort: any failure is logged and turned
    into a ``None`` result rather than raised.

    Args:
        conversation_id: Conversation whose connection receives the audio.
        text: Text to synthesize.
        chunk_index: Position of this chunk, sent to the client as ``index``.
        chunk_total: Number of chunks in the reply, sent as ``total``.

    Returns:
        The public URL of the uploaded audio, or ``None`` when TTS is
        disabled, the text is blank, the provider returned no audio, or any
        step failed.
    """
    if not settings.enable_tts:
        return None
    # Nothing to speak: skip the provider round-trip entirely.
    if not text or not text.strip():
        return None
    try:
        tts = get_tts_provider()
        audio_bytes = await tts.synthesize(text)
        if not audio_bytes:
            logger.warning(
                "TTS skipped: synthesize returned empty. Check TTS config in .env"
            )
            return None
        ext = _tts_object_ext(settings.tts_codec)
        content_type = _tts_codec_to_content_type(settings.tts_codec)
        storage = get_object_storage()
        # Random object name so repeated chunks never overwrite each other.
        key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
        public_url = storage.upload(key, audio_bytes, content_type)
        await redis_service.append_tts_audio_url_to_last_ai_message(
            conversation_id, public_url
        )
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": {
                    "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
                    "format": settings.tts_codec,
                    "audio_url": public_url,
                    "index": chunk_index,
                    "total": chunk_total,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
        return public_url
    except Exception as e:
        err_str = str(e)
        if "PkgExhausted" in err_str:
            # Known quota condition: a short warning is enough.
            logger.warning(
                "TTS skipped: 腾讯云语音合成资源包已用尽,请在控制台购买或开通后付费: %s",
                err_str[:100],
            )
        else:
            # Unexpected failure (synthesis, upload, Redis or send): keep the
            # traceback so the failing step can be identified.
            logger.error("TTS synthesize failed: %s", e, exc_info=True)
        return None
|
2026-03-19 09:58:02 +08:00
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
# ── Agent instances (moved out of ConnectionManager) ─────────────────────
# Module-level singletons, created once when this module is imported and
# shared by the functions below.
conversation_agent = ConversationAgent()
chat_orchestrator = ChatOrchestrator()
memory_agent = MemoryAgent()
background_runner = BackgroundTaskRunner()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── 分段流状态 ──────────────────────────────────────────────────
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
@dataclass
class SegmentStreamState:
    """Per-conversation, per-voice-session bookkeeping for segment processing.

    Lets segments be transcribed in parallel while their transcripts are
    still handed on in segment-index order.
    """

    # Guards the mutable fields below.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Indices currently in flight; an index is discarded when its task ends.
    pending_indices: set[int] = field(default_factory=set)
    # Indices that were persisted or recognized as already persisted.
    processed_indices: set[int] = field(default_factory=set)
    # index -> (transcript, segment) waiting for all earlier indices.
    buffered_transcripts: dict[int, tuple[str, Segment]] = field(default_factory=dict)
    # Highest index already handed on in order; -1 means none yet.
    consumed_index: int = -1
    # Tasks registered for this state and not yet finished.
    active_tasks: set[asyncio.Task] = field(default_factory=set)
    # True once the delayed listening feedback has been sent.
    listening_feedback_sent: bool = False
    # Handle of the delayed listening-feedback task, if one is scheduled.
    listening_feedback_task: asyncio.Task | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Live segment states keyed by (conversation_id, voice_session_id).
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}


def get_or_create_segment_state(
    conversation_id: str,
    voice_session_id: str,
) -> SegmentStreamState:
    """Return the state for this conversation/voice session, creating it on first use."""
    lookup_key = (conversation_id, voice_session_id)
    existing = _segment_states.get(lookup_key)
    if existing is not None:
        return existing
    created = SegmentStreamState()
    _segment_states[lookup_key] = created
    return created
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_segment_task(
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track a segment-processing task on its stream state until it finishes.

    The task is added to the state's ``active_tasks``. When it completes it
    is removed again; if no tasks remain and the conversation is no longer in
    ``manager.active_connections``, the state is dropped from the registry.
    An exception raised by the task is logged with its traceback.

    Args:
        conversation_id: Conversation the task belongs to.
        voice_session_id: Voice session the task belongs to.
        task: The already-created asyncio task to track.
    """
    state_key = (conversation_id, voice_session_id)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        if not state.active_tasks and conversation_id not in manager.active_connections:
            _segment_states.pop(state_key, None)
        # Task.exception() raises CancelledError on a cancelled task.
        if done_task.cancelled():
            return
        exc = done_task.exception()
        if exc:
            logger.error(
                "分段处理任务异常 "
                f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
                # A done-callback runs outside any ``except`` block, so
                # exc_info=True would find no active exception and log no
                # traceback. Pass the exception instance instead.
                exc_info=exc,
            )

    task.add_done_callback(_cleanup)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cleanup_segment_states(conversation_id: str) -> None:
    """Drop this conversation's segment states that have no active tasks.

    Intended to be called after the connection is closed. States that still
    have running tasks are left in place.
    """
    # Snapshot the items first: the registry is mutated inside the loop.
    for key, state in list(_segment_states.items()):
        owner_id = key[0]
        if owner_id != conversation_id or state.active_tasks:
            continue
        _segment_states.pop(key, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── 工具函数 ────────────────────────────────────────────────────
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
def _utc_now() -> datetime:
|
|
|
|
|
|
return datetime.now(timezone.utc)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
def _mark_conversation_active(
    conversation: Conversation, at: Optional[datetime] = None
) -> datetime:
    """Stamp ``conversation.last_message_at`` and return the timestamp used.

    Uses ``at`` when given, otherwise the current UTC time. The change is
    only applied to the object; committing it is left to the caller.
    """
    if at:
        stamp = at
    else:
        stamp = _utc_now()
    conversation.last_message_at = stamp
    return stamp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
    """Return the id as a string, or ``LEGACY_VOICE_SESSION_ID`` when it is missing or empty."""
    return str(voice_session_id) if voice_session_id else LEGACY_VOICE_SESSION_ID
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
def _voice_session_id_from_client_segment_id(
|
|
|
|
|
|
client_segment_id: Optional[str],
|
|
|
|
|
|
) -> Optional[str]:
|
2026-03-18 17:18:23 +08:00
|
|
|
|
if not client_segment_id:
|
|
|
|
|
|
return None
|
|
|
|
|
|
session_id, separator, _ = client_segment_id.rpartition("-")
|
|
|
|
|
|
if separator and session_id:
|
|
|
|
|
|
return session_id
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
|
|
|
|
|
|
"""构建分段语音的幂等标识(conversation_id + voice_session_id + segment_index)。"""
|
|
|
|
|
|
return f"audio-segment:{voice_session_id}:{segment_index}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse ``(voice_session_id, segment_index)`` out of a segment marker.

    Accepts ``audio-segment:{voice_session_id}:{index}`` as well as the
    legacy ``audio-segment:{index}`` form, which maps to
    ``LEGACY_VOICE_SESSION_ID``. Returns ``None`` when the value is not a
    segment marker or the index is not an integer.
    """
    marker = "audio-segment:"
    if not (audio_url and audio_url.startswith(marker)):
        return None
    body = audio_url.removeprefix(marker)
    # Split on the LAST colon so a session id may itself contain colons.
    session_part, colon, index_part = body.rpartition(":")
    try:
        if not colon:
            # Legacy marker: the payload is just the index.
            return (LEGACY_VOICE_SESSION_ID, int(body))
        return (_normalize_voice_session_id(session_part), int(index_part))
    except ValueError:
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in a segment marker, or ``None`` if it cannot be parsed."""
    parsed = _extract_segment_scope(audio_url)
    return parsed[0] if parsed else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
|
|
|
|
|
|
if not transcript_text:
|
|
|
|
|
|
return True
|
|
|
|
|
|
return transcript_text.startswith("转写失败")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Look up an already persisted segment for this conversation/session/index.

    Matching is done on the marker produced by ``_build_segment_audio_url``.
    When several rows match, the most recently created one is returned;
    ``None`` means no such segment exists.
    """
    marker = _build_segment_audio_url(voice_session_id, segment_index)
    query = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.audio_url == marker,
    )
    query = query.order_by(Segment.created_at.desc())
    rows = (await db.execute(query)).scalars().all()
    # Each row is re-checked against the same criteria before it is accepted.
    # NOTE(review): presumably a safeguard for sessions/fakes that do not
    # apply the WHERE clause - confirm before removing.
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id and row.audio_url == marker
        ),
        None,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest index N such that segments 0..N of this voice session are all persisted.

    Used to restore ``consumed_index`` after a reconnect. Returns ``-1``
    when segment 0 has not been persisted.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()

    seen: Set[int] = set()
    for row in rows:
        if row.conversation_id != conversation_id:
            continue
        scope = _extract_segment_scope(row.audio_url)
        # Rows without a parsable marker, or from another session, are ignored.
        if scope and scope[0] == voice_session_id:
            seen.add(scope[1])

    # Walk up from 0 until the first missing index.
    first_gap = 0
    while first_gap in seen:
        first_gap += 1
    return first_gap - 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── 过渡反馈 ────────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
# Seconds that _delayed_listening_feedback waits before sending the feedback.
LISTENING_FEEDBACK_DELAY_SEC = 5.0
# Text of the transition message sent by _send_segment_transition_feedback.
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
) -> None:
    """Send one ``LISTENING_FEEDBACK_TEXT`` transition message to the conversation.

    The message is an ``AGENT_RESPONSE`` flagged with ``transition: True``
    and tagged with ``segment_index``. It is invoked by the delayed
    listening-feedback task.
    """
    data = {
        "text": LISTENING_FEEDBACK_TEXT,
        "transition": True,
        "segment_index": segment_index,
    }
    message = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": data,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, message)
|
2026-03-18 17:18:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """Send the listening feedback once, ``LISTENING_FEEDBACK_DELAY_SEC`` seconds after being started.

    The feedback is sent at most once per segment state, tracked by
    ``listening_feedback_sent``. Per the original note it should not be sent
    once the user has stopped recording.
    NOTE(review): nothing in this function checks for that - presumably the
    caller cancels this task; confirm.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        # Check-and-set under the lock so only one sender can win.
        if state.listening_feedback_sent:
            return
        state.listening_feedback_sent = True
        # The pending-task handle is cleared once the feedback is claimed.
        state.listening_feedback_task = None
    # NOTE(review): the send is assumed to happen after the lock is released;
    # the source's indentation was lost, so confirm against the original.
    await _send_segment_transition_feedback(conversation_id, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── 分段语音异步处理 ────────────────────────────────────────────
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
async def process_audio_segment(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
) -> None:
    """Process one voice segment: transcribe, persist idempotently, hand on in order.

    Flow:
      1. Load the conversation and user; report an ERROR and stop if either
         is missing (or the conversation is soft-deleted).
      2. On the first segment seen by a fresh state, restore
         ``consumed_index`` from the segments already persisted.
      3. Transcribe the audio and send a TRANSCRIPT message.
      4. Stop with an ERROR message when the transcription failed, or
         silently when this index is already persisted.
      5. Persist the segment, buffer its transcript, and pass every
         transcript that is now contiguous with ``consumed_index`` to
         ``process_user_message`` in index order.

    Failures are logged and reported to the client as an ERROR message; they
    are never raised. ``segment_index`` is always removed from
    ``pending_indices`` at the end.

    Args:
        conversation_id: Target conversation.
        user_id: Owner of the conversation.
        voice_session_id: Voice session the segment belongs to.
        segment_index: Position of the segment within the voice session.
        audio_base64: Base64-encoded audio, transcribed as ``m4a``.
        audio_duration: Reported duration; stored only when greater than 0.
        is_last: Echoed to the client in the TRANSCRIPT message.
    """
    # NOTE(review): block nesting below was reconstructed from a source whose
    # indentation was lost; verify against the original file.
    state = get_or_create_segment_state(conversation_id, voice_session_id)

    try:
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            if not conversation or conversation.deleted_at is not None:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "对话不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            if not user:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "用户不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return

            # A completely fresh state may belong to a reconnect.
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )

            if should_prime_state:
                # The query runs outside the lock; asyncio.Lock is not
                # re-entrant, so it is re-acquired only for the update.
                persisted_contiguous_index = (
                    await _get_persisted_contiguous_segment_index(
                        db=db,
                        conversation_id=conversation_id,
                        voice_session_id=voice_session_id,
                    )
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        state.consumed_index = max(
                            state.consumed_index, persisted_contiguous_index
                        )

            try:
                audio_bytes = base64.b64decode(audio_base64)
            except Exception:
                # Undecodable payload: transcribe empty audio so the normal
                # transcription-failure path reports it to the client.
                audio_bytes = b""
            transcript_text = await get_asr_provider().transcribe(
                audio_bytes, format="m4a"
            )
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.TRANSCRIPT,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": transcript_text or "",
                        "audio_duration": audio_duration,
                        "voice_session_id": voice_session_id,
                        "segment_index": segment_index,
                        "is_last": is_last,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

            if _is_transcribe_failure(transcript_text):
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            "message": f"分段 {segment_index} 转写失败,请重试该片段",
                            "segment_index": segment_index,
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return

            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                # Duplicate delivery: record it and stop without re-persisting.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.info(
                    "分段已存在,按幂等处理跳过: "
                    f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
                )
                return

            segment = Segment(
                id=str(uuid.uuid4()),
                conversation_id=conversation_id,
                transcript_text=transcript_text or "",
                audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                audio_duration_seconds=audio_duration
                if audio_duration > 0
                else None,
                processed=False,
            )
            db.add(segment)
            user_message_timestamp = _mark_conversation_active(conversation)
            await db.commit()
            await db.refresh(segment)
            await background_runner.queue_message(conversation.user_id, segment.id)

            ready_segments: List[Tuple[int, str, Segment]] = []
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (
                    transcript_text or "",
                    segment,
                )

                # Drain every buffered transcript that directly follows the
                # last consumed index; later ones wait for the gap to fill.
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1

            # The lock is released here, so other segments are not blocked
            # while the agent produces its replies.
            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at
                    or user_message_timestamp,
                )

    except Exception as e:
        logger.error(
            f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
            exc_info=True,
        )
        # The notification is best-effort: if the connection is already gone,
        # the send must not raise out of the task.
        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.ERROR,
                    "data": {
                        "message": f"分段处理失败: {str(e)}",
                        "segment_index": segment_index,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
        except Exception as send_error:
            logger.warning(f"发送错误消息失败: {send_error}")
    finally:
        async with state.lock:
            state.pending_indices.discard(segment_index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── 用户消息处理 ────────────────────────────────────────────────
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: Optional[User] = None,
    user_message_timestamp: Optional[datetime] = None,
) -> None:
    """Generate the agent's reply to one user message and deliver it.

    The message is delegated to ``chat_orchestrator.process_user_message``
    (which, per the original note, routes to ProfileAgent or InterviewAgent).
    The joined responses are stored on ``segment.agent_response`` and
    committed. Each response is then sent as an ``AGENT_RESPONSE`` message
    followed by its TTS audio, and the collected TTS URLs are written to the
    segment row.

    Errors are logged and, while the conversation is still connected,
    reported to the client as an ERROR message; nothing is raised.

    NOTE(review): block nesting was reconstructed from a source whose
    indentation was lost; verify against the original file.
    """
    try:
        # A segment carrying an audio_url originated from voice input.
        is_from_voice = bool(segment.audio_url)
        voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
        audio_dur = getattr(segment, "audio_duration_seconds", None)
        responses = await chat_orchestrator.process_user_message(
            conversation_id=conversation_id,
            user_message=user_message,
            user=user,
            conversation=conversation,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            db=db,
            apply_extracted_profile_fn=apply_extracted_profile,
            get_missing_profile_fields_fn=get_missing_profile_fields,
            get_filled_profile_fields_fn=get_filled_profile_fields,
            user_message_timestamp=user_message_timestamp,
            audio_duration_seconds=audio_dur,
        )

        # Persist the reply before delivery is attempted.
        segment.agent_response = "\n\n".join(responses)
        _mark_conversation_active(conversation)
        await db.commit()

        tts_urls: list[str] = []
        n = len(responses)
        for i, response_text in enumerate(responses):
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": response_text,
                        "index": i,
                        "total": n,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            url = await _send_tts_audio(
                conversation_id,
                response_text,
                chunk_index=i,
                chunk_total=n,
            )
            if url:
                tts_urls.append(url)
            # Short pause between consecutive chunks, none after the last.
            if i < n - 1:
                await asyncio.sleep(0.5)

        # Store NULL rather than an empty list when no audio was produced.
        await db.execute(
            update(Segment)
            .where(Segment.id == segment.id)
            .values(tts_audio_urls=tts_urls if tts_urls else None)
        )
        await db.commit()

    except Exception as e:
        logger.error(f"处理用户消息失败: {e}", exc_info=True)
        # Only notify clients that are still connected; a failing
        # notification is logged and otherwise ignored.
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": f"生成回应失败: {str(e)}"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            except Exception as send_error:
                logger.warning(f"发送错误消息失败: {send_error}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── 对话结束处理 ────────────────────────────────────────────────
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
async def process_conversation_segments(
    conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
):
    """Submit a conversation's unprocessed segments for organizing when it ends.

    Most processing has already happened incrementally; this call hands the
    remaining unprocessed segments to the Celery task
    ``process_memoir_segments`` and then flushes the background runner's
    pending work for the user.

    The quota check goes through the injected ``quota_service``. When the
    user's quota is exhausted the submission is skipped, but the flush still
    runs. A missing or soft-deleted conversation is ignored entirely.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation or conversation.deleted_at is not None:
        return

    unprocessed_query = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False,
    )
    segments = (await db.execute(unprocessed_query)).scalars().all()

    should_submit = bool(segments)
    if should_submit:
        user = await db.get(User, conversation.user_id)
        # The quota is only checked when the user record can be loaded.
        if user:
            can_submit, _ = await quota_service.check_can_submit_organize(
                user.id, user.subscription_type
            )
            if not can_submit:
                logger.info(
                    f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}"
                )
                should_submit = False

    if should_submit:
        segment_ids = [seg.id for seg in segments]
        try:
            # Imported here rather than at module level.
            from app.tasks.memoir_tasks import process_memoir_segments

            process_memoir_segments.delay(conversation.user_id, segment_ids)
            logger.info(
                f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}"
            )
        except Exception as e:
            logger.error(f"提交 Celery 任务失败: {e}")

    # Every path past the conversation check ends with a flush.
    await background_runner.flush_pending(conversation.user_id)
|