Files
life-echo/api/app/features/conversation/ws/pipeline.py

597 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""核心消息处理管道Agent 调用、ASR 转写、分段有序聚合"""
import asyncio
import base64
from app.core.logging import get_logger
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
if TYPE_CHECKING:
from app.features.quota.service import QuotaService
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents import ConversationAgent, MemoryAgent
from app.agents.memoir_processor import BackgroundTaskRunner
from app.agents.prompts.profile_prompts import format_user_profile_context
from app.core.db import AsyncSessionLocal
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import LEGACY_VOICE_SESSION_ID, MessageType
from app.features.conversation.ws.profile_collector import (
apply_extracted_profile,
get_filled_profile_fields,
get_missing_profile_fields,
)
from app.features.user.models import User
from app.core.dependencies import get_asr_provider, get_tts_provider
from app.features.memoir.state_service import get_or_create_state
logger = get_logger(__name__)
async def _send_tts_audio(conversation_id: str, text: str) -> None:
    """Synthesize ``text`` and push the audio to the client as a TTS_AUDIO message.

    This is best-effort: nothing is raised to the caller.

    * An empty synthesis result is logged as a warning and nothing is sent.
    * Any exception is logged; when its text contains ``"PkgExhausted"`` it is
      logged as a warning (truncated to 100 characters), otherwise as an error.

    Args:
        conversation_id: Conversation whose connection receives the audio.
        text: Text to synthesize.
    """
    try:
        synthesized = await get_tts_provider().synthesize(text)
        if synthesized:
            payload = {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": {
                    # Audio bytes travel as base64 text inside the JSON message.
                    "audio_base64": base64.b64encode(synthesized).decode("utf-8"),
                    "format": "mp3",
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
            await manager.send_message(conversation_id, payload)
        else:
            logger.warning(
                "TTS skipped: synthesize returned empty. Check TTS config in .env"
            )
    except Exception as exc:
        detail = str(exc)
        if "PkgExhausted" not in detail:
            logger.error("TTS synthesize failed: %s", exc)
            return
        # Quota exhaustion is an expected operational state, so only warn.
        logger.warning(
            "TTS skipped: 腾讯云语音合成资源包已用尽,请在控制台购买或开通后付费: %s",
            detail[:100],
        )
# ── Agent instances (moved out of ConnectionManager) ─────────────────────
# Created once at import time and shared by every call in this module.
conversation_agent = ConversationAgent()
memory_agent = MemoryAgent()
background_runner = BackgroundTaskRunner()
# ── 分段流状态 ──────────────────────────────────────────────────
@dataclass
class SegmentStreamState:
    """Per-voice-session segment-processing state (parallel ASR + ordered aggregation)."""

    # Guards reads and writes of the fields below across concurrent segment tasks.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Segment indices still in flight; process_audio_segment discards its own
    # index when it exits. NOTE(review): nothing visible in this file adds to
    # this set — presumably the WebSocket handler does; confirm.
    pending_indices: Set[int] = field(default_factory=set)
    # Indices that were persisted, or found already persisted, in this session.
    processed_indices: Set[int] = field(default_factory=set)
    # index -> (transcript, Segment) waiting for all lower indices to be consumed.
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest index already released for agent processing; -1 means none yet.
    consumed_index: int = -1
    # Tasks registered through register_segment_task that have not finished.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    # Set to True by _delayed_listening_feedback so the message is sent once.
    listening_feedback_sent: bool = False
    # Reset to None by _delayed_listening_feedback; presumably assigned by the
    # code that schedules that coroutine (not visible here).
    listening_feedback_task: Optional[asyncio.Task] = None
# Process-wide registry of stream states keyed by (conversation_id, voice_session_id).
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}


def get_or_create_segment_state(
    conversation_id: str,
    voice_session_id: str,
) -> SegmentStreamState:
    """Return the stream state for one voice session, creating it on first use.

    Args:
        conversation_id: Conversation the voice session belongs to.
        voice_session_id: Identifier of the voice session.

    Returns:
        The state registered in ``_segment_states`` for this pair.
    """
    registry_key = (conversation_id, voice_session_id)
    try:
        return _segment_states[registry_key]
    except KeyError:
        fresh_state = SegmentStreamState()
        _segment_states[registry_key] = fresh_state
        return fresh_state
def register_segment_task(
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track a segment-processing task and clean up after it finishes.

    The task is added to the voice session's ``active_tasks``. When it
    completes, it is removed again; if no tasks remain and the conversation
    is no longer in ``manager.active_connections``, the session state is
    dropped from the registry. A task that ended with an exception is logged
    together with its traceback.

    Args:
        conversation_id: Conversation the task works for.
        voice_session_id: Voice session the task works for.
        task: The already-scheduled asyncio task to track.
    """
    state_key = (conversation_id, voice_session_id)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        if not state.active_tasks and conversation_id not in manager.active_connections:
            _segment_states.pop(state_key, None)
        # Task.exception() raises CancelledError on a cancelled task, so bail out first.
        if done_task.cancelled():
            return
        exc = done_task.exception()
        if exc is not None:
            # A done-callback runs outside any `except` block, so
            # `exc_info=True` would find no active exception and log no
            # traceback. Passing the exception instance attaches it.
            logger.error(
                "分段处理任务异常 "
                f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
                exc_info=exc,
            )

    task.add_done_callback(_cleanup)
def cleanup_segment_states(conversation_id: str) -> None:
    """Drop this conversation's segment states that have no active tasks.

    Intended to be called after the connection is closed. States that still
    have running tasks are left in place.
    """
    # Iterate over a snapshot of the keys so entries can be deleted in the loop.
    for registry_key in list(_segment_states):
        if registry_key[0] != conversation_id:
            continue
        if _segment_states[registry_key].active_tasks:
            continue
        del _segment_states[registry_key]
# ── 工具函数 ────────────────────────────────────────────────────
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    return datetime.now(tz=timezone.utc)
def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime:
    """Stamp ``conversation.last_message_at`` and return the timestamp used.

    Args:
        conversation: Conversation object to update in place (not committed here).
        at: Timestamp to record; the current UTC time is used when omitted.

    Returns:
        The timestamp that was written.
    """
    if not at:
        at = _utc_now()
    conversation.last_message_at = at
    return at
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
    """Return the id as a string, or ``LEGACY_VOICE_SESSION_ID`` when it is empty or ``None``."""
    return str(voice_session_id) if voice_session_id else LEGACY_VOICE_SESSION_ID
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
    """Return the part of ``client_segment_id`` before its last ``-``.

    Returns ``None`` when the id is empty, contains no ``-``, or has nothing
    in front of the last ``-``.
    """
    if client_segment_id:
        last_dash = client_segment_id.rfind("-")
        # A position of 0 means an empty prefix; -1 means no dash at all.
        if last_dash > 0:
            return client_segment_id[:last_dash]
    return None
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
    """Build the idempotency marker stored in ``Segment.audio_url``.

    The marker is ``audio-segment:{voice_session_id}:{segment_index}``. The
    conversation id is not part of the string; lookups filter on it separately.
    """
    return "audio-segment:{}:{}".format(voice_session_id, segment_index)
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse ``(voice_session_id, segment_index)`` out of a segment ``audio_url``.

    Accepts ``audio-segment:{voice_session_id}:{index}`` and the legacy form
    ``audio-segment:{index}``, which maps to ``LEGACY_VOICE_SESSION_ID``.

    Returns:
        The parsed pair, or ``None`` when the value is empty, lacks the
        ``audio-segment:`` prefix, or its index is not an integer.
    """
    marker = "audio-segment:"
    if not audio_url or not audio_url.startswith(marker):
        return None
    # Without a ":" in the remainder, rpartition puts the whole remainder in the last slot.
    session_part, colon, index_part = audio_url[len(marker):].rpartition(":")
    try:
        index = int(index_part)
    except ValueError:
        return None
    if colon:
        return (_normalize_voice_session_id(session_part), index)
    return (LEGACY_VOICE_SESSION_ID, index)
def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in ``audio_url``, or ``None`` if it cannot be parsed."""
    scope = _extract_segment_scope(audio_url)
    return scope[0] if scope else None
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
    """Return True when the transcript is empty/``None`` or starts with the failure marker ``转写失败``."""
    return not transcript_text or transcript_text.startswith("转写失败")
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Look up the newest persisted segment for this conversation, session and index.

    Segments are matched on ``conversation_id`` plus the ``audio_url`` marker
    built by ``_build_segment_audio_url``.

    Returns:
        The most recently created matching ``Segment``, or ``None``.
    """
    target_url = _build_segment_audio_url(voice_session_id, segment_index)
    query = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == target_url,
        )
        .order_by(Segment.created_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()
    # The SQL filter is re-applied to the returned rows before picking the first one.
    return next(
        (
            row
            for row in rows
            if row.conversation_id == conversation_id and row.audio_url == target_url
        ),
        None,
    )
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest index N such that segments 0..N of this voice session are all persisted.

    Used to restore ordering state after a reconnect.

    Returns:
        The last index of the gap-free run starting at 0, or ``-1`` when
        segment 0 has not been persisted.
    """
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    rows = (await db.execute(query)).scalars().all()
    scopes = (
        _extract_segment_scope(row.audio_url)
        for row in rows
        if row.conversation_id == conversation_id
    )
    # Keep only indices whose audio_url parses and belongs to this voice session.
    seen_indices: Set[int] = {
        scope[1] for scope in scopes if scope and scope[0] == voice_session_id
    }
    last_contiguous = -1
    while (last_contiguous + 1) in seen_indices:
        last_contiguous += 1
    return last_contiguous
# ── Transition feedback ─────────────────────────────────────────
# Seconds _delayed_listening_feedback sleeps before sending the message below.
LISTENING_FEEDBACK_DELAY_SEC = 5.0
# User-facing text of the one-off "I'm listening" message (sent verbatim).
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
) -> None:
    """Send the fixed "I'm listening" companion message to the client.

    The message is an AGENT_RESPONSE whose data carries ``transition: True``
    and the given ``segment_index``. Called by the delayed feedback task.
    """
    feedback_body = {
        "text": LISTENING_FEEDBACK_TEXT,
        "transition": True,
        "segment_index": segment_index,
    }
    envelope = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": feedback_body,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, envelope)
async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """Send the "I'm listening" message once, ``LISTENING_FEEDBACK_DELAY_SEC`` after being started.

    The ``listening_feedback_sent`` flag on the session state ensures the
    message goes out at most once per voice session.
    NOTE(review): skipping the message once the user has stopped recording
    presumably relies on the caller cancelling this task or setting the flag —
    that code is not visible here.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    stream_state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with stream_state.lock:
        already_sent = stream_state.listening_feedback_sent
        if not already_sent:
            stream_state.listening_feedback_sent = True
            stream_state.listening_feedback_task = None
    if already_sent:
        return
    # The send happens after the lock is released.
    await _send_segment_transition_feedback(conversation_id, 0)
# ── 分段语音异步处理 ────────────────────────────────────────────
async def process_audio_segment(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
) -> None:
    """Transcribe one audio segment, persist it idempotently, and feed the agent in index order.

    Steps:
      1. Load the conversation and user; report an ERROR and stop if either is missing.
      2. On first use of an empty session state, restore ``consumed_index``
         from the segments already persisted (reconnect recovery).
      3. Decode the audio, run ASR, and send the TRANSCRIPT message.
      4. On ASR failure, send an ERROR asking the client to retry, and stop.
      5. If a segment with the same session/index marker already exists, record
         the index as processed and stop (idempotent replay).
      6. Otherwise persist a new ``Segment``, queue it on the background
         runner, buffer it, and pass every now-contiguous buffered segment to
         ``process_user_message`` in ascending index order.

    Exceptions raised while processing are logged and reported to the client
    as an ERROR message instead of propagating. ``segment_index`` is always
    removed from ``state.pending_indices`` on exit.

    Args:
        conversation_id: Conversation the segment belongs to.
        user_id: Owner of the conversation.
        voice_session_id: Voice session the segment belongs to.
        segment_index: Position of the segment within the voice session.
        audio_base64: Base64-encoded audio, transcribed as ``m4a``.
        audio_duration: Duration echoed back in the TRANSCRIPT message.
        is_last: Flag echoed back in the TRANSCRIPT message.
    """
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    try:
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            if not conversation:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "对话不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return
            if not user:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "用户不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            # Decide under the lock, query outside it: asyncio.Lock is not
            # re-entrant, and the DB read should not block other segment tasks.
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )
            if should_prime_state:
                persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
                    db=db,
                    conversation_id=conversation_id,
                    voice_session_id=voice_session_id,
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        state.consumed_index = max(state.consumed_index, persisted_contiguous_index)

            try:
                audio_bytes = base64.b64decode(audio_base64)
            except Exception as decode_error:
                # Still fall back to empty audio and let ASR report the
                # failure, but leave a trace of the real cause.
                logger.warning(
                    "Audio segment base64 decode failed "
                    "(conversation_id=%s, voice_session_id=%s, segment_index=%s): %s",
                    conversation_id,
                    voice_session_id,
                    segment_index,
                    decode_error,
                )
                audio_bytes = b""

            transcript_text = await get_asr_provider().transcribe(
                audio_bytes, format="m4a"
            )
            await manager.send_message(conversation_id, {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {
                    "text": transcript_text or "",
                    "audio_duration": audio_duration,
                    "voice_session_id": voice_session_id,
                    "segment_index": segment_index,
                    "is_last": is_last,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            if _is_transcribe_failure(transcript_text):
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {
                        "message": f"分段 {segment_index} 转写失败,请重试该片段",
                        "segment_index": segment_index,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                # Replay of an already-persisted segment: record it and stop.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.info(
                    "分段已存在,按幂等处理跳过: "
                    f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
                )
                return

            segment = Segment(
                id=str(uuid.uuid4()),
                conversation_id=conversation_id,
                transcript_text=transcript_text or "",
                audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                processed=False,
            )
            db.add(segment)
            user_message_timestamp = _mark_conversation_active(conversation)
            await db.commit()
            await db.refresh(segment)
            await background_runner.queue_message(conversation.user_id, segment.id)

            # Buffer this segment, then pull out the contiguous run that
            # starts right after the last consumed index.
            ready_segments: List[Tuple[int, str, Segment]] = []
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1

            # NOTE(review): a buffered Segment may have been created by another
            # task's session; whether changes made to it are persisted by this
            # `db` depends on session configuration — confirm.
            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at or user_message_timestamp,
                )
    except Exception as e:
        logger.error(
            f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
            exc_info=True,
        )
        # Notifying the client is best-effort, as in process_user_message: a
        # failed send must not escape the handler.
        try:
            await manager.send_message(conversation_id, {
                "type": MessageType.ERROR,
                "data": {
                    "message": f"分段处理失败: {str(e)}",
                    "segment_index": segment_index,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
        except Exception as send_error:
            logger.warning(f"发送错误消息失败: {send_error}")
    finally:
        async with state.lock:
            state.pending_indices.discard(segment_index)
# ── 用户消息处理 ────────────────────────────────────────────────
async def _deliver_agent_responses(conversation_id: str, responses: List[str]) -> None:
    """Send each reply as an AGENT_RESPONSE followed by its TTS audio, pausing 0.5 s between replies."""
    total = len(responses)
    for i, response_text in enumerate(responses):
        await manager.send_message(conversation_id, {
            "type": MessageType.AGENT_RESPONSE,
            "conversation_id": conversation_id,
            "data": {"text": response_text, "index": i, "total": total},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
        await _send_tts_audio(conversation_id, response_text)
        if i < total - 1:
            await asyncio.sleep(0.5)


async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: Optional[User] = None,
    user_message_timestamp: Optional[datetime] = None,
) -> None:
    """Generate and deliver the agent's reply to one user message.

    Two modes:

    * Profile collection — when ``user`` is given and still has missing
      profile fields, profile data is extracted from the message, applied,
      and a follow-up is generated. If anything in this mode fails, the error
      is logged and processing falls through to interview mode.
    * Interview — the reply is generated from the user's memoir state and
      profile context. Failures are logged and, if the conversation is still
      connected, reported to the client as an ERROR message.

    In both modes the replies are joined with blank lines into
    ``segment.agent_response``, the conversation is marked active, the
    session is committed, and each reply is sent with its TTS audio.

    Args:
        conversation_id: Conversation to reply in.
        user_message: Text of the user's message.
        conversation: Conversation object; stage and activity time may be updated.
        segment: Segment that receives ``agent_response``; a non-empty
            ``audio_url`` marks the message as coming from voice.
        db: Session used for reads and commits.
        user: User object; profile collection is skipped when omitted.
        user_message_timestamp: Timestamp forwarded to the agent.
    """
    agent = conversation_agent

    if user:
        missing = get_missing_profile_fields(user)
        if missing:
            try:
                extracted = await agent.extract_profile_from_message(
                    user_message, missing, conversation_id=conversation_id
                )
                if extracted:
                    await apply_extracted_profile(user, extracted, db)
                remaining = get_missing_profile_fields(user)
                filled = get_filled_profile_fields(user)
                is_from_voice = bool(segment.audio_url)
                responses = await agent.generate_profile_followup(
                    conversation_id=conversation_id,
                    user_message=user_message,
                    missing_fields=remaining,
                    filled_fields=filled,
                    nickname=user.nickname or "",
                    is_from_voice=is_from_voice,
                    voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
                    user_message_timestamp=user_message_timestamp,
                )
                segment.agent_response = "\n\n".join(responses)
                _mark_conversation_active(conversation)
                await db.commit()
                await _deliver_agent_responses(conversation_id, responses)
                return
            except Exception as e:
                # Deliberate fallback: continue in interview mode below.
                logger.error(f"资料收集处理失败: {e}", exc_info=True)

    state = await get_or_create_state(conversation.user_id, db)
    if conversation.conversation_stage != state.current_stage:
        conversation.conversation_stage = state.current_stage
        await db.commit()

    # The former per-message query that loaded every segment of the
    # conversation to build an unused `covered_topics` list has been removed.

    user_profile_context = ""
    if user:
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )

    try:
        is_from_voice = bool(segment.audio_url)
        responses = await agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            is_from_voice=is_from_voice,
            voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
            user_message_timestamp=user_message_timestamp,
        )
        segment.agent_response = "\n\n".join(responses)
        _mark_conversation_active(conversation)
        await db.commit()
        await _deliver_agent_responses(conversation_id, responses)
    except Exception as e:
        logger.error(f"处理用户消息失败: {e}", exc_info=True)
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": f"生成回应失败: {str(e)}"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
            except Exception as send_error:
                logger.warning(f"发送错误消息失败: {send_error}")
# ── 对话结束处理 ────────────────────────────────────────────────
async def process_conversation_segments(
    conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
):
    """Submit a conversation's unprocessed segments for chapter generation.

    Called when a conversation ends. Most processing has already been done
    incrementally by Celery tasks; this submits whatever is still unprocessed
    to Celery right away. The quota check goes through the injected
    ``quota_service`` rather than importing quota internals directly.

    Flow:
      * Unknown conversation: return without doing anything.
      * No unprocessed segments, or the user's quota check fails: skip the
        Celery submission.
      * Otherwise: submit the segment ids to ``process_memoir_segments``; a
        submission failure is logged, not raised.
      * In every case except the unknown conversation, finish by flushing the
        background runner's pending work for the user.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    unprocessed_query = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False,  # noqa: E712 - SQL expression, not a Python truth test
    )
    unprocessed = (await db.execute(unprocessed_query)).scalars().all()

    if unprocessed:
        user = await db.get(User, conversation.user_id)
        quota_ok = True
        if user:
            quota_ok, _ = await quota_service.check_can_submit_organize(
                user.id, user.subscription_type
            )
            if not quota_ok:
                logger.info(
                    f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}"
                )
        if quota_ok:
            segment_ids = [item.id for item in unprocessed]
            try:
                # Imported at call time, inside the guarded block.
                from app.tasks.memoir_tasks import process_memoir_segments
                process_memoir_segments.delay(conversation.user_id, segment_ids)
                logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
            except Exception as exc:
                logger.error(f"提交 Celery 任务失败: {exc}")

    await background_runner.flush_pending(conversation.user_id)