chore/ 删除无用文件

This commit is contained in:
Kevin
2026-03-19 14:36:14 +08:00
parent 2f60858c9c
commit c6e07ce5ca
135 changed files with 2111 additions and 4510 deletions

View File

@@ -2,6 +2,7 @@
WebSocket 路由:实时对话通信
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
"""
import asyncio
from app.core.logging import get_logger
import uuid
@@ -57,23 +58,31 @@ async def websocket_endpoint(
"""
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
)
return
payload = verify_token(token)
if not payload or payload.get("type") != "access":
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
)
return
user_id = payload.get("sub")
if not user_id:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
)
return
async with AsyncSessionLocal() as db:
user = await db.get(User, user_id)
if not user:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
)
return
await manager.connect(websocket, conversation_id)
@@ -81,12 +90,15 @@ async def websocket_endpoint(
quota_service = QuotaService(db=db)
try:
await manager.send_message(conversation_id, {
"type": MessageType.CONNECT,
"conversation_id": conversation_id,
"data": {"status": "connected"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.CONNECT,
"conversation_id": conversation_id,
"data": {"status": "connected"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
conversation = await db.get(Conversation, conversation_id)
if not conversation:
@@ -101,14 +113,19 @@ async def websocket_endpoint(
else:
if conversation.user_id != user_id:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "无权访问此对话"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "无权访问此对话"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception:
pass
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话"
)
return
history = await redis_service.get_conversation_history(conversation_id)
@@ -122,12 +139,19 @@ async def websocket_endpoint(
nickname=user.nickname or "",
)
for i, text in enumerate(greetings):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(greetings)},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {
"text": text,
"index": i,
"total": len(greetings),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
if i < len(greetings) - 1:
await asyncio.sleep(0.5)
except Exception as e:
@@ -141,18 +165,27 @@ async def websocket_endpoint(
grew_up_place=user.grew_up_place,
occupation=user.occupation,
)
opening_messages = await conversation_agent.generate_opening_message(
conversation_id=conversation_id,
memoir_state=state,
user_profile_context=user_profile_context,
opening_messages = (
await conversation_agent.generate_opening_message(
conversation_id=conversation_id,
memoir_state=state,
user_profile_context=user_profile_context,
)
)
for i, text in enumerate(opening_messages):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(opening_messages)},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {
"text": text,
"index": i,
"total": len(opening_messages),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
if i < len(opening_messages) - 1:
await asyncio.sleep(0.5)
except Exception as e:
@@ -161,7 +194,9 @@ async def websocket_endpoint(
while True:
try:
if websocket.application_state != WebSocketState.CONNECTED:
logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}")
logger.info(
f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}"
)
break
message = await websocket.receive_json()
msg_type = message.get("type")
@@ -170,13 +205,23 @@ async def websocket_endpoint(
text_message = message.get("data", {}).get("text", "")
if text_message:
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
can_send, quota_msg = await check_ws_quota(
quota_service, user_id, user.subscription_type
)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": quota_msg,
"code": "QUOTA_EXCEEDED",
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
continue
segment = Segment(
@@ -186,10 +231,14 @@ async def websocket_endpoint(
processed=False,
)
db.add(segment)
user_message_timestamp = _mark_conversation_active(conversation)
user_message_timestamp = _mark_conversation_active(
conversation
)
await db.commit()
await db.refresh(segment)
await background_runner.queue_message(conversation.user_id, segment.id)
await background_runner.queue_message(
conversation.user_id, segment.id
)
await process_user_message(
conversation_id=conversation_id,
@@ -198,18 +247,24 @@ async def websocket_endpoint(
segment=segment,
db=db,
user=user,
user_message_timestamp=segment.created_at or user_message_timestamp,
user_message_timestamp=segment.created_at
or user_message_timestamp,
)
elif msg_type == MessageType.RECORDING_STARTED:
data = message.get("data", {})
voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
voice_session_id = _normalize_voice_session_id(
data.get("voice_session_id")
)
segment_state = get_or_create_segment_state(
conversation_id,
voice_session_id,
)
async with segment_state.lock:
if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done():
if (
segment_state.listening_feedback_task is not None
and not segment_state.listening_feedback_task.done()
):
continue
if segment_state.listening_feedback_sent:
continue
@@ -227,52 +282,74 @@ async def websocket_endpoint(
segment_index_raw = data.get("segment_index")
voice_session_id = _normalize_voice_session_id(
data.get("voice_session_id")
or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
or _voice_session_id_from_client_segment_id(
data.get("client_segment_id")
)
)
is_last = bool(data.get("is_last", False))
audio_duration = int(data.get("duration", 0) or 0)
if not audio_base64:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
if segment_index_raw is None:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 segment_index"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "缺少 segment_index"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
try:
segment_index = int(segment_index_raw)
except (TypeError, ValueError):
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "segment_index 必须为整数"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "segment_index 必须为整数"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
if segment_index < 0:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "segment_index 不能为负数"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "segment_index 不能为负数"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
can_send, quota_msg = await check_ws_quota(
quota_service, user_id, user.subscription_type
)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": quota_msg,
"code": "QUOTA_EXCEEDED",
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
segment_state = get_or_create_segment_state(
@@ -323,13 +400,23 @@ async def websocket_endpoint(
audio_duration = data.get("duration", 0)
if audio_base64:
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
can_send, quota_msg = await check_ws_quota(
quota_service, user_id, user.subscription_type
)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": quota_msg,
"code": "QUOTA_EXCEEDED",
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
continue
logger.info(f"收到音频消息,时长: {audio_duration}s")
@@ -337,18 +424,25 @@ async def websocket_endpoint(
try:
asr = get_asr_provider()
audio_bytes = base64.b64decode(audio_base64)
transcript_text = await asr.transcribe(audio_bytes, "m4a")
transcript_text = await asr.transcribe(
audio_bytes, "m4a"
)
logger.info("ASR 转写结果: %s", transcript_text)
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {
"text": transcript_text,
"audio_duration": audio_duration,
await manager.send_message(
conversation_id,
{
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {
"text": transcript_text,
"audio_duration": audio_duration,
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
)
segment = Segment(
id=str(uuid.uuid4()),
@@ -358,12 +452,18 @@ async def websocket_endpoint(
processed=False,
)
db.add(segment)
user_message_timestamp = _mark_conversation_active(conversation)
user_message_timestamp = _mark_conversation_active(
conversation
)
await db.commit()
await db.refresh(segment)
await background_runner.queue_message(conversation.user_id, segment.id)
await background_runner.queue_message(
conversation.user_id, segment.id
)
if transcript_text and not transcript_text.startswith("转写失败"):
if transcript_text and not transcript_text.startswith(
"转写失败"
):
await process_user_message(
conversation_id=conversation_id,
user_message=transcript_text,
@@ -371,99 +471,141 @@ async def websocket_endpoint(
segment=segment,
db=db,
user=user,
user_message_timestamp=segment.created_at or user_message_timestamp,
user_message_timestamp=segment.created_at
or user_message_timestamp,
)
else:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "语音转写失败,请重试或使用文字输入"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": "语音转写失败,请重试或使用文字输入"
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
except Exception as e:
logger.error(f"处理音频消息失败: {e}", exc_info=True)
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": f"处理音频消息失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": f"处理音频消息失败: {str(e)}"
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
elif msg_type == MessageType.TRANSCRIBE_ONLY:
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
if not audio_base64:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
try:
asr = get_asr_provider()
audio_bytes = base64.b64decode(audio_base64)
transcript_text = await asr.transcribe(audio_bytes, "m4a")
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {"text": transcript_text or ""},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {"text": transcript_text or ""},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception as e:
logger.error(f"仅转写失败: {e}", exc_info=True)
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": f"转写失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": f"转写失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
elif msg_type == MessageType.END_CONVERSATION:
conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc)
await db.commit()
await process_conversation_segments(conversation_id, db, quota_service)
await process_conversation_segments(
conversation_id, db, quota_service
)
await manager.send_message(conversation_id, {
"type": MessageType.END_CONVERSATION,
"conversation_id": conversation_id,
"data": {"status": "ended"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.END_CONVERSATION,
"conversation_id": conversation_id,
"data": {"status": "ended"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
break
except RuntimeError as e:
error_msg = str(e)
if (
"disconnect" in error_msg.lower()
or "Cannot call \"receive\"" in error_msg
or "accept" in error_msg.lower() and "not connected" in error_msg.lower()
or 'Cannot call "receive"' in error_msg
or "accept" in error_msg.lower()
and "not connected" in error_msg.lower()
):
logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}")
logger.info(
f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}"
)
break
else:
logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
break
except WebSocketDisconnect:
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
logger.info(
f"WebSocket 断开连接: conversation_id={conversation_id}"
)
break
except Exception as e:
logger.error(f"处理消息时发生错误: {e}", exc_info=True)
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
break