fix/various fixes
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
"""WebSocket 连接管理器:仅负责连接注册/注销和消息收发"""
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import HTTPException, WebSocket
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from app.core.logging import get_logger
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.features.quota.service import QuotaService
|
||||
|
||||
@@ -17,7 +18,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.agents import ConversationAgent, MemoryAgent
|
||||
from app.agents.chat import ChatOrchestrator
|
||||
from app.agents.memoir import BackgroundTaskRunner
|
||||
from app.core.config import settings
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_asr_provider, get_tts_provider
|
||||
from app.features.conversation.models import Conversation, Segment
|
||||
from app.features.conversation.ws.connection_manager import manager
|
||||
from app.features.conversation.ws.message_types import (
|
||||
@@ -30,14 +33,14 @@ from app.features.conversation.ws.profile_collector import (
|
||||
get_missing_profile_fields,
|
||||
)
|
||||
from app.features.user.models import User
|
||||
from app.core.config import settings
|
||||
from app.core.dependencies import get_asr_provider, get_tts_provider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def _send_tts_audio(conversation_id: str, text: str) -> None:
|
||||
"""Synthesize text to speech and send TTS_AUDIO if successful."""
|
||||
if not settings.enable_tts:
|
||||
return
|
||||
try:
|
||||
tts = get_tts_provider()
|
||||
audio_bytes = await tts.synthesize(text)
|
||||
@@ -325,7 +328,7 @@ async def process_audio_segment(
|
||||
async with AsyncSessionLocal() as db:
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
user = await db.get(User, user_id)
|
||||
if not conversation:
|
||||
if not conversation or conversation.deleted_at is not None:
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
@@ -562,7 +565,7 @@ async def process_conversation_segments(
|
||||
配额检查通过注入的 quota_service 完成,不直接 import quota 内部函数。
|
||||
"""
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
if not conversation:
|
||||
if not conversation or conversation.deleted_at is not None:
|
||||
return
|
||||
|
||||
stmt = select(Segment).where(
|
||||
|
||||
@@ -4,7 +4,7 @@ WebSocket 路由:实时对话通信
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from app.core.logging import get_logger
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -13,8 +13,11 @@ from starlette.websockets import WebSocketState
|
||||
|
||||
from app.agents.chat.prompts_profile import format_user_profile_context
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_asr_provider
|
||||
from app.core.logging import get_logger
|
||||
from app.core.security import verify_token
|
||||
from app.features.conversation.models import Conversation, Segment
|
||||
from app.features.conversation.service import ConversationService
|
||||
from app.features.conversation.ws.connection_manager import manager
|
||||
from app.features.conversation.ws.message_types import MessageType
|
||||
from app.features.conversation.ws.pipeline import (
|
||||
@@ -34,13 +37,9 @@ from app.features.conversation.ws.pipeline import (
|
||||
)
|
||||
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
|
||||
from app.features.conversation.ws.quota_guard import check_ws_quota
|
||||
from app.features.memoir.state_service import get_or_create_state
|
||||
from app.features.quota.service import QuotaService
|
||||
from app.features.user.models import User
|
||||
import base64
|
||||
|
||||
from app.core.dependencies import get_asr_provider
|
||||
from app.core.redis import redis_service
|
||||
from app.features.memoir.state_service import get_or_create_state
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -88,6 +87,7 @@ async def websocket_endpoint(
|
||||
await manager.connect(websocket, conversation_id)
|
||||
|
||||
quota_service = QuotaService(db=db)
|
||||
conversation_service = ConversationService(db=db, quota_service=quota_service)
|
||||
|
||||
try:
|
||||
await manager.send_message(
|
||||
@@ -127,8 +127,26 @@ async def websocket_endpoint(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话"
|
||||
)
|
||||
return
|
||||
if conversation.deleted_at is not None:
|
||||
try:
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "对话已删除"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="对话已删除"
|
||||
)
|
||||
return
|
||||
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
history = await conversation_service.ensure_redis_history_from_segments(
|
||||
conversation_id
|
||||
)
|
||||
if not history:
|
||||
missing_profile = get_missing_profile_fields(user)
|
||||
if missing_profile:
|
||||
|
||||
Reference in New Issue
Block a user