Merge branch 'refactor/backend-architecture' into development
This commit is contained in:
481
api/app/features/conversation/ws/router.py
Normal file
481
api/app/features/conversation/ws/router.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""
|
||||
WebSocket 路由:实时对话通信
|
||||
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
|
||||
"""
|
||||
import asyncio
|
||||
from app.core.logging import get_logger
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, status
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from app.agents.prompts.profile_prompts import format_user_profile_context
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.security import verify_token
|
||||
from app.features.conversation.models import Conversation, Segment
|
||||
from app.features.conversation.ws.connection_manager import manager
|
||||
from app.features.conversation.ws.message_types import MessageType
|
||||
from app.features.conversation.ws.pipeline import (
|
||||
SegmentStreamState, # noqa: F401 — re-export for test backward compat
|
||||
_delayed_listening_feedback,
|
||||
_mark_conversation_active,
|
||||
_normalize_voice_session_id,
|
||||
_voice_session_id_from_client_segment_id,
|
||||
background_runner,
|
||||
cleanup_segment_states,
|
||||
conversation_agent,
|
||||
get_or_create_segment_state,
|
||||
process_audio_segment,
|
||||
process_conversation_segments,
|
||||
process_user_message,
|
||||
register_segment_task,
|
||||
)
|
||||
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
|
||||
from app.features.conversation.ws.quota_guard import check_ws_quota
|
||||
from app.features.quota.service import QuotaService
|
||||
from app.features.user.models import User
|
||||
import base64
|
||||
|
||||
from app.core.dependencies import get_asr_provider
|
||||
from app.core.redis import redis_service
|
||||
from app.features.memoir.state_service import get_or_create_state
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str,
|
||||
):
|
||||
"""
|
||||
WebSocket 端点:处理实时对话
|
||||
|
||||
Args:
|
||||
websocket: WebSocket 连接
|
||||
conversation_id: 对话 ID
|
||||
"""
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
|
||||
return
|
||||
|
||||
payload = verify_token(token)
|
||||
if not payload or payload.get("type") != "access":
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
|
||||
return
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
|
||||
return
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
user = await db.get(User, user_id)
|
||||
if not user:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
|
||||
return
|
||||
|
||||
await manager.connect(websocket, conversation_id)
|
||||
|
||||
quota_service = QuotaService(db=db)
|
||||
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.CONNECT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"status": "connected"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
id=conversation_id,
|
||||
user_id=user_id,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
status="active",
|
||||
)
|
||||
db.add(conversation)
|
||||
await db.commit()
|
||||
else:
|
||||
if conversation.user_id != user_id:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
|
||||
return
|
||||
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
if not history:
|
||||
missing_profile = get_missing_profile_fields(user)
|
||||
if missing_profile:
|
||||
try:
|
||||
greetings = await conversation_agent.generate_profile_greeting(
|
||||
conversation_id=conversation_id,
|
||||
missing_fields=missing_profile,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
for i, text in enumerate(greetings):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(greetings)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
if i < len(greetings) - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
|
||||
else:
|
||||
try:
|
||||
state = await get_or_create_state(user_id, db)
|
||||
user_profile_context = format_user_profile_context(
|
||||
birth_year=user.birth_year,
|
||||
birth_place=user.birth_place,
|
||||
grew_up_place=user.grew_up_place,
|
||||
occupation=user.occupation,
|
||||
)
|
||||
opening_messages = await conversation_agent.generate_opening_message(
|
||||
conversation_id=conversation_id,
|
||||
memoir_state=state,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
for i, text in enumerate(opening_messages):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(opening_messages)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
if i < len(opening_messages) - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"发送空对话开场白失败: {e}", exc_info=True)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if websocket.application_state != WebSocketState.CONNECTED:
|
||||
logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}")
|
||||
break
|
||||
message = await websocket.receive_json()
|
||||
msg_type = message.get("type")
|
||||
|
||||
if msg_type == MessageType.TEXT:
|
||||
text_message = message.get("data", {}).get("text", "")
|
||||
|
||||
if text_message:
|
||||
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
segment = Segment(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
transcript_text=text_message,
|
||||
processed=False,
|
||||
)
|
||||
db.add(segment)
|
||||
user_message_timestamp = _mark_conversation_active(conversation)
|
||||
await db.commit()
|
||||
await db.refresh(segment)
|
||||
await background_runner.queue_message(conversation.user_id, segment.id)
|
||||
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=text_message,
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
user=user,
|
||||
user_message_timestamp=segment.created_at or user_message_timestamp,
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.RECORDING_STARTED:
|
||||
data = message.get("data", {})
|
||||
voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
|
||||
segment_state = get_or_create_segment_state(
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
)
|
||||
async with segment_state.lock:
|
||||
if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done():
|
||||
continue
|
||||
if segment_state.listening_feedback_sent:
|
||||
continue
|
||||
delayed_task = asyncio.create_task(
|
||||
_delayed_listening_feedback(
|
||||
conversation_id=conversation_id,
|
||||
voice_session_id=voice_session_id,
|
||||
)
|
||||
)
|
||||
segment_state.listening_feedback_task = delayed_task
|
||||
|
||||
elif msg_type == MessageType.AUDIO_SEGMENT:
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
segment_index_raw = data.get("segment_index")
|
||||
voice_session_id = _normalize_voice_session_id(
|
||||
data.get("voice_session_id")
|
||||
or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
|
||||
)
|
||||
is_last = bool(data.get("is_last", False))
|
||||
audio_duration = int(data.get("duration", 0) or 0)
|
||||
|
||||
if not audio_base64:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
if segment_index_raw is None:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 segment_index"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
try:
|
||||
segment_index = int(segment_index_raw)
|
||||
except (TypeError, ValueError):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 必须为整数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
if segment_index < 0:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 不能为负数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
segment_state = get_or_create_segment_state(
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
)
|
||||
should_process = False
|
||||
async with segment_state.lock:
|
||||
already_seen = (
|
||||
segment_index in segment_state.pending_indices
|
||||
or segment_index in segment_state.processed_indices
|
||||
or segment_index <= segment_state.consumed_index
|
||||
)
|
||||
if not already_seen:
|
||||
segment_state.pending_indices.add(segment_index)
|
||||
should_process = True
|
||||
|
||||
if not should_process:
|
||||
logger.info(
|
||||
"收到重复分段,跳过处理: "
|
||||
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
|
||||
)
|
||||
continue
|
||||
|
||||
if is_last:
|
||||
async with segment_state.lock:
|
||||
t = segment_state.listening_feedback_task
|
||||
segment_state.listening_feedback_task = None
|
||||
if t is not None and not t.done():
|
||||
t.cancel()
|
||||
|
||||
task = asyncio.create_task(
|
||||
process_audio_segment(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
voice_session_id=voice_session_id,
|
||||
segment_index=segment_index,
|
||||
audio_base64=audio_base64,
|
||||
audio_duration=audio_duration,
|
||||
is_last=is_last,
|
||||
)
|
||||
)
|
||||
register_segment_task(conversation_id, voice_session_id, task)
|
||||
|
||||
elif msg_type == MessageType.AUDIO_MESSAGE:
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
audio_duration = data.get("duration", 0)
|
||||
|
||||
if audio_base64:
|
||||
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
logger.info(f"收到音频消息,时长: {audio_duration}s")
|
||||
|
||||
try:
|
||||
asr = get_asr_provider()
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
transcript_text = await asr.transcribe(audio_bytes, "m4a")
|
||||
logger.info("ASR 转写结果: %s", transcript_text)
|
||||
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": transcript_text,
|
||||
"audio_duration": audio_duration,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
segment = Segment(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
transcript_text=transcript_text,
|
||||
audio_url=f"audio:{audio_duration}s",
|
||||
processed=False,
|
||||
)
|
||||
db.add(segment)
|
||||
user_message_timestamp = _mark_conversation_active(conversation)
|
||||
await db.commit()
|
||||
await db.refresh(segment)
|
||||
await background_runner.queue_message(conversation.user_id, segment.id)
|
||||
|
||||
if transcript_text and not transcript_text.startswith("转写失败"):
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=transcript_text,
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
user=user,
|
||||
user_message_timestamp=segment.created_at or user_message_timestamp,
|
||||
)
|
||||
else:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "语音转写失败,请重试或使用文字输入"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理音频消息失败: {e}", exc_info=True)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"处理音频消息失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
elif msg_type == MessageType.TRANSCRIBE_ONLY:
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
if not audio_base64:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
continue
|
||||
try:
|
||||
asr = get_asr_provider()
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
transcript_text = await asr.transcribe(audio_bytes, "m4a")
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": transcript_text or ""},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"仅转写失败: {e}", exc_info=True)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"转写失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
elif msg_type == MessageType.END_CONVERSATION:
|
||||
conversation.status = "ended"
|
||||
conversation.ended_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
await process_conversation_segments(conversation_id, db, quota_service)
|
||||
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.END_CONVERSATION,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"status": "ended"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
break
|
||||
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e)
|
||||
if (
|
||||
"disconnect" in error_msg.lower()
|
||||
or "Cannot call \"receive\"" in error_msg
|
||||
or "accept" in error_msg.lower() and "not connected" in error_msg.lower()
|
||||
):
|
||||
logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}")
|
||||
break
|
||||
else:
|
||||
logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}", exc_info=True)
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
await manager.disconnect(conversation_id)
|
||||
cleanup_segment_states(conversation_id)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
|
||||
await manager.disconnect(conversation_id)
|
||||
cleanup_segment_states(conversation_id)
|
||||
finally:
|
||||
await manager.disconnect(conversation_id)
|
||||
cleanup_segment_states(conversation_id)
|
||||
Reference in New Issue
Block a user