fix/various fixes

This commit is contained in:
Kevin
2026-03-20 15:15:35 +08:00
parent 7f57f96c25
commit 7317bf10cd
112 changed files with 3790 additions and 2242 deletions

View File

@@ -1,10 +1,11 @@
"""WebSocket 连接管理器:仅负责连接注册/注销和消息收发"""
from app.core.logging import get_logger
from typing import Dict
from fastapi import HTTPException, WebSocket
from app.core.logging import get_logger
logger = get_logger(__name__)

View File

@@ -2,12 +2,13 @@
import asyncio
import base64
from app.core.logging import get_logger
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from app.core.logging import get_logger
if TYPE_CHECKING:
from app.features.quota.service import QuotaService
@@ -17,7 +18,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents import ConversationAgent, MemoryAgent
from app.agents.chat import ChatOrchestrator
from app.agents.memoir import BackgroundTaskRunner
from app.core.config import settings
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_tts_provider
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import (
@@ -30,14 +33,14 @@ from app.features.conversation.ws.profile_collector import (
get_missing_profile_fields,
)
from app.features.user.models import User
from app.core.config import settings
from app.core.dependencies import get_asr_provider, get_tts_provider
logger = get_logger(__name__)
async def _send_tts_audio(conversation_id: str, text: str) -> None:
"""Synthesize text to speech and send TTS_AUDIO if successful."""
if not settings.enable_tts:
return
try:
tts = get_tts_provider()
audio_bytes = await tts.synthesize(text)
@@ -325,7 +328,7 @@ async def process_audio_segment(
async with AsyncSessionLocal() as db:
conversation = await db.get(Conversation, conversation_id)
user = await db.get(User, user_id)
if not conversation:
if not conversation or conversation.deleted_at is not None:
await manager.send_message(
conversation_id,
{
@@ -562,7 +565,7 @@ async def process_conversation_segments(
配额检查通过注入的 quota_service 完成,不直接 import quota 内部函数。
"""
conversation = await db.get(Conversation, conversation_id)
if not conversation:
if not conversation or conversation.deleted_at is not None:
return
stmt = select(Segment).where(

View File

@@ -4,7 +4,7 @@ WebSocket 路由:实时对话通信
"""
import asyncio
from app.core.logging import get_logger
import base64
import uuid
from datetime import datetime, timezone
@@ -13,8 +13,11 @@ from starlette.websockets import WebSocketState
from app.agents.chat.prompts_profile import format_user_profile_context
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
@@ -34,13 +37,9 @@ from app.features.conversation.ws.pipeline import (
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.memoir.state_service import get_or_create_state
from app.features.quota.service import QuotaService
from app.features.user.models import User
import base64
from app.core.dependencies import get_asr_provider
from app.core.redis import redis_service
from app.features.memoir.state_service import get_or_create_state
logger = get_logger(__name__)
@@ -88,6 +87,7 @@ async def websocket_endpoint(
await manager.connect(websocket, conversation_id)
quota_service = QuotaService(db=db)
conversation_service = ConversationService(db=db, quota_service=quota_service)
try:
await manager.send_message(
@@ -127,8 +127,26 @@ async def websocket_endpoint(
code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话"
)
return
if conversation.deleted_at is not None:
try:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "对话已删除"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception:
pass
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="对话已删除"
)
return
history = await redis_service.get_conversation_history(conversation_id)
history = await conversation_service.ensure_redis_history_from_segments(
conversation_id
)
if not history:
missing_profile = get_missing_profile_fields(user)
if missing_profile: