chore/ 删除无用文件
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""WebSocket 连接管理器:仅负责连接注册/注销和消息收发"""
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebSocket 消息类型定义"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
LEGACY_VOICE_SESSION_ID = "legacy"
|
||||
@@ -6,6 +7,7 @@ LEGACY_VOICE_SESSION_ID = "legacy"
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""WebSocket 消息类型"""
|
||||
|
||||
CONNECT = "connect"
|
||||
RECORDING_STARTED = "recording_started"
|
||||
AUDIO_CHUNK = "audio_chunk"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""核心消息处理管道:Agent 调用、ASR 转写、分段有序聚合"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from app.core.logging import get_logger
|
||||
@@ -19,7 +20,10 @@ from app.agents.memoir import BackgroundTaskRunner
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.features.conversation.models import Conversation, Segment
|
||||
from app.features.conversation.ws.connection_manager import manager
|
||||
from app.features.conversation.ws.message_types import LEGACY_VOICE_SESSION_ID, MessageType
|
||||
from app.features.conversation.ws.message_types import (
|
||||
LEGACY_VOICE_SESSION_ID,
|
||||
MessageType,
|
||||
)
|
||||
from app.features.conversation.ws.profile_collector import (
|
||||
apply_extracted_profile,
|
||||
get_filled_profile_fields,
|
||||
@@ -42,15 +46,18 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None:
|
||||
"TTS skipped: synthesize returned empty. Check TTS config in .env"
|
||||
)
|
||||
return
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TTS_AUDIO,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
|
||||
"format": settings.tts_codec,
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.TTS_AUDIO,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
|
||||
"format": settings.tts_codec,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
)
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
if "PkgExhausted" in err_str:
|
||||
@@ -61,6 +68,7 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None:
|
||||
else:
|
||||
logger.error("TTS synthesize failed: %s", e)
|
||||
|
||||
|
||||
# ── Agent 实例(从 ConnectionManager 移出) ─────────────────────
|
||||
conversation_agent = ConversationAgent()
|
||||
chat_orchestrator = ChatOrchestrator()
|
||||
@@ -70,6 +78,7 @@ background_runner = BackgroundTaskRunner()
|
||||
|
||||
# ── 分段流状态 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class SegmentStreamState:
|
||||
"""会话内分段处理状态(用于并行 ASR + 有序聚合)"""
|
||||
@@ -136,11 +145,14 @@ def cleanup_segment_states(conversation_id: str) -> None:
|
||||
|
||||
# ── 工具函数 ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime:
|
||||
def _mark_conversation_active(
|
||||
conversation: Conversation, at: Optional[datetime] = None
|
||||
) -> datetime:
|
||||
activity_time = at or _utc_now()
|
||||
conversation.last_message_at = activity_time
|
||||
return activity_time
|
||||
@@ -152,7 +164,9 @@ def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
|
||||
return LEGACY_VOICE_SESSION_ID
|
||||
|
||||
|
||||
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
|
||||
def _voice_session_id_from_client_segment_id(
|
||||
client_segment_id: Optional[str],
|
||||
) -> Optional[str]:
|
||||
if not client_segment_id:
|
||||
return None
|
||||
session_id, separator, _ = client_segment_id.rpartition("-")
|
||||
@@ -171,11 +185,14 @@ def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]
|
||||
prefix = "audio-segment:"
|
||||
if not audio_url or not audio_url.startswith(prefix):
|
||||
return None
|
||||
payload = audio_url[len(prefix):]
|
||||
payload = audio_url[len(prefix) :]
|
||||
voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
|
||||
try:
|
||||
if separator:
|
||||
return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw))
|
||||
return (
|
||||
_normalize_voice_session_id(voice_session_id_raw),
|
||||
int(segment_index_raw),
|
||||
)
|
||||
return (LEGACY_VOICE_SESSION_ID, int(payload))
|
||||
except ValueError:
|
||||
return None
|
||||
@@ -201,14 +218,21 @@ async def _find_existing_segment_by_index(
|
||||
segment_index: int,
|
||||
) -> Optional[Segment]:
|
||||
segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index)
|
||||
stmt = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id,
|
||||
Segment.audio_url == segment_audio_url,
|
||||
).order_by(Segment.created_at.desc())
|
||||
stmt = (
|
||||
select(Segment)
|
||||
.where(
|
||||
Segment.conversation_id == conversation_id,
|
||||
Segment.audio_url == segment_audio_url,
|
||||
)
|
||||
.order_by(Segment.created_at.desc())
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
candidates = result.scalars().all()
|
||||
for item in candidates:
|
||||
if item.conversation_id == conversation_id and item.audio_url == segment_audio_url:
|
||||
if (
|
||||
item.conversation_id == conversation_id
|
||||
and item.audio_url == segment_audio_url
|
||||
):
|
||||
return item
|
||||
return None
|
||||
|
||||
@@ -252,16 +276,19 @@ async def _send_segment_transition_feedback(
|
||||
segment_index: int,
|
||||
) -> None:
|
||||
"""发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)。"""
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": LISTENING_FEEDBACK_TEXT,
|
||||
"transition": True,
|
||||
"segment_index": segment_index,
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": LISTENING_FEEDBACK_TEXT,
|
||||
"transition": True,
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
)
|
||||
|
||||
|
||||
async def _delayed_listening_feedback(
|
||||
@@ -281,6 +308,7 @@ async def _delayed_listening_feedback(
|
||||
|
||||
# ── 分段语音异步处理 ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def process_audio_segment(
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
@@ -298,18 +326,24 @@ async def process_audio_segment(
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
user = await db.get(User, user_id)
|
||||
if not conversation:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "对话不存在,分段处理已取消"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "对话不存在,分段处理已取消"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
return
|
||||
if not user:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "用户不存在,分段处理已取消"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "用户不存在,分段处理已取消"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
async with state.lock:
|
||||
@@ -320,14 +354,18 @@ async def process_audio_segment(
|
||||
)
|
||||
|
||||
if should_prime_state:
|
||||
persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
voice_session_id=voice_session_id,
|
||||
persisted_contiguous_index = (
|
||||
await _get_persisted_contiguous_segment_index(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
voice_session_id=voice_session_id,
|
||||
)
|
||||
)
|
||||
if persisted_contiguous_index >= 0:
|
||||
async with state.lock:
|
||||
state.consumed_index = max(state.consumed_index, persisted_contiguous_index)
|
||||
state.consumed_index = max(
|
||||
state.consumed_index, persisted_contiguous_index
|
||||
)
|
||||
|
||||
try:
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
@@ -336,28 +374,34 @@ async def process_audio_segment(
|
||||
transcript_text = await get_asr_provider().transcribe(
|
||||
audio_bytes, format="m4a"
|
||||
)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": transcript_text or "",
|
||||
"audio_duration": audio_duration,
|
||||
"voice_session_id": voice_session_id,
|
||||
"segment_index": segment_index,
|
||||
"is_last": is_last,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
if _is_transcribe_failure(transcript_text):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"message": f"分段 {segment_index} 转写失败,请重试该片段",
|
||||
"text": transcript_text or "",
|
||||
"audio_duration": audio_duration,
|
||||
"voice_session_id": voice_session_id,
|
||||
"segment_index": segment_index,
|
||||
"is_last": is_last,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
if _is_transcribe_failure(transcript_text):
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": f"分段 {segment_index} 转写失败,请重试该片段",
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
existing_segment = await _find_existing_segment_by_index(
|
||||
@@ -391,7 +435,10 @@ async def process_audio_segment(
|
||||
ready_segments: List[Tuple[int, str, Segment]] = []
|
||||
async with state.lock:
|
||||
state.processed_indices.add(segment_index)
|
||||
state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
|
||||
state.buffered_transcripts[segment_index] = (
|
||||
transcript_text or "",
|
||||
segment,
|
||||
)
|
||||
|
||||
next_index = state.consumed_index + 1
|
||||
while next_index in state.buffered_transcripts:
|
||||
@@ -408,7 +455,8 @@ async def process_audio_segment(
|
||||
segment=ordered_segment,
|
||||
db=db,
|
||||
user=user,
|
||||
user_message_timestamp=ordered_segment.created_at or user_message_timestamp,
|
||||
user_message_timestamp=ordered_segment.created_at
|
||||
or user_message_timestamp,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -416,14 +464,17 @@ async def process_audio_segment(
|
||||
f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
|
||||
exc_info=True,
|
||||
)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": f"分段处理失败: {str(e)}",
|
||||
"segment_index": segment_index,
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": f"分段处理失败: {str(e)}",
|
||||
"segment_index": segment_index,
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
)
|
||||
finally:
|
||||
async with state.lock:
|
||||
state.pending_indices.discard(segment_index)
|
||||
@@ -431,6 +482,7 @@ async def process_audio_segment(
|
||||
|
||||
# ── 用户消息处理 ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def process_user_message(
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
@@ -463,12 +515,19 @@ async def process_user_message(
|
||||
await db.commit()
|
||||
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": response_text,
|
||||
"index": i,
|
||||
"total": len(responses),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
await _send_tts_audio(conversation_id, response_text)
|
||||
if i < len(responses) - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
@@ -477,17 +536,21 @@ async def process_user_message(
|
||||
logger.error(f"处理用户消息失败: {e}", exc_info=True)
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"生成回应失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"生成回应失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
|
||||
|
||||
# ── 对话结束处理 ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def process_conversation_segments(
|
||||
conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
|
||||
):
|
||||
@@ -528,8 +591,11 @@ async def process_conversation_segments(
|
||||
segment_ids = [seg.id for seg in segments]
|
||||
try:
|
||||
from app.tasks.memoir_tasks import process_memoir_segments
|
||||
|
||||
process_memoir_segments.delay(conversation.user_id, segment_ids)
|
||||
logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
|
||||
logger.info(
|
||||
f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"提交 Celery 任务失败: {e}")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""用户资料收集:缺失字段检测、提取与应用"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.features.user.models import User
|
||||
@@ -6,7 +7,10 @@ from app.features.user.models import User
|
||||
|
||||
def get_missing_profile_fields(user: User) -> list:
|
||||
"""检查用户缺失的资料字段"""
|
||||
from app.agents.chat.prompts_profile import get_missing_profile_fields as _get_missing
|
||||
from app.agents.chat.prompts_profile import (
|
||||
get_missing_profile_fields as _get_missing,
|
||||
)
|
||||
|
||||
return _get_missing(
|
||||
birth_year=user.birth_year,
|
||||
birth_place=user.birth_place,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebSocket 配额检查:通过注入 QuotaService,不直接 import quota 内部函数。"""
|
||||
|
||||
from app.features.quota.service import QuotaService
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
WebSocket 路由:实时对话通信
|
||||
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from app.core.logging import get_logger
|
||||
import uuid
|
||||
@@ -57,23 +58,31 @@ async def websocket_endpoint(
|
||||
"""
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
|
||||
)
|
||||
return
|
||||
|
||||
payload = verify_token(token)
|
||||
if not payload or payload.get("type") != "access":
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
|
||||
)
|
||||
return
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
|
||||
)
|
||||
return
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
user = await db.get(User, user_id)
|
||||
if not user:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
|
||||
)
|
||||
return
|
||||
|
||||
await manager.connect(websocket, conversation_id)
|
||||
@@ -81,12 +90,15 @@ async def websocket_endpoint(
|
||||
quota_service = QuotaService(db=db)
|
||||
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.CONNECT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"status": "connected"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.CONNECT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"status": "connected"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
if not conversation:
|
||||
@@ -101,14 +113,19 @@ async def websocket_endpoint(
|
||||
else:
|
||||
if conversation.user_id != user_id:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话"
|
||||
)
|
||||
return
|
||||
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
@@ -122,12 +139,19 @@ async def websocket_endpoint(
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
for i, text in enumerate(greetings):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(greetings)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": text,
|
||||
"index": i,
|
||||
"total": len(greetings),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
if i < len(greetings) - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
except Exception as e:
|
||||
@@ -141,18 +165,27 @@ async def websocket_endpoint(
|
||||
grew_up_place=user.grew_up_place,
|
||||
occupation=user.occupation,
|
||||
)
|
||||
opening_messages = await conversation_agent.generate_opening_message(
|
||||
conversation_id=conversation_id,
|
||||
memoir_state=state,
|
||||
user_profile_context=user_profile_context,
|
||||
opening_messages = (
|
||||
await conversation_agent.generate_opening_message(
|
||||
conversation_id=conversation_id,
|
||||
memoir_state=state,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
)
|
||||
for i, text in enumerate(opening_messages):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(opening_messages)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": text,
|
||||
"index": i,
|
||||
"total": len(opening_messages),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
if i < len(opening_messages) - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
except Exception as e:
|
||||
@@ -161,7 +194,9 @@ async def websocket_endpoint(
|
||||
while True:
|
||||
try:
|
||||
if websocket.application_state != WebSocketState.CONNECTED:
|
||||
logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}")
|
||||
logger.info(
|
||||
f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}"
|
||||
)
|
||||
break
|
||||
message = await websocket.receive_json()
|
||||
msg_type = message.get("type")
|
||||
@@ -170,13 +205,23 @@ async def websocket_endpoint(
|
||||
text_message = message.get("data", {}).get("text", "")
|
||||
|
||||
if text_message:
|
||||
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
|
||||
can_send, quota_msg = await check_ws_quota(
|
||||
quota_service, user_id, user.subscription_type
|
||||
)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": quota_msg,
|
||||
"code": "QUOTA_EXCEEDED",
|
||||
},
|
||||
"timestamp": datetime.now(
|
||||
timezone.utc
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
segment = Segment(
|
||||
@@ -186,10 +231,14 @@ async def websocket_endpoint(
|
||||
processed=False,
|
||||
)
|
||||
db.add(segment)
|
||||
user_message_timestamp = _mark_conversation_active(conversation)
|
||||
user_message_timestamp = _mark_conversation_active(
|
||||
conversation
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(segment)
|
||||
await background_runner.queue_message(conversation.user_id, segment.id)
|
||||
await background_runner.queue_message(
|
||||
conversation.user_id, segment.id
|
||||
)
|
||||
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
@@ -198,18 +247,24 @@ async def websocket_endpoint(
|
||||
segment=segment,
|
||||
db=db,
|
||||
user=user,
|
||||
user_message_timestamp=segment.created_at or user_message_timestamp,
|
||||
user_message_timestamp=segment.created_at
|
||||
or user_message_timestamp,
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.RECORDING_STARTED:
|
||||
data = message.get("data", {})
|
||||
voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
|
||||
voice_session_id = _normalize_voice_session_id(
|
||||
data.get("voice_session_id")
|
||||
)
|
||||
segment_state = get_or_create_segment_state(
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
)
|
||||
async with segment_state.lock:
|
||||
if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done():
|
||||
if (
|
||||
segment_state.listening_feedback_task is not None
|
||||
and not segment_state.listening_feedback_task.done()
|
||||
):
|
||||
continue
|
||||
if segment_state.listening_feedback_sent:
|
||||
continue
|
||||
@@ -227,52 +282,74 @@ async def websocket_endpoint(
|
||||
segment_index_raw = data.get("segment_index")
|
||||
voice_session_id = _normalize_voice_session_id(
|
||||
data.get("voice_session_id")
|
||||
or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
|
||||
or _voice_session_id_from_client_segment_id(
|
||||
data.get("client_segment_id")
|
||||
)
|
||||
)
|
||||
is_last = bool(data.get("is_last", False))
|
||||
audio_duration = int(data.get("duration", 0) or 0)
|
||||
|
||||
if not audio_base64:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
if segment_index_raw is None:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 segment_index"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 segment_index"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
segment_index = int(segment_index_raw)
|
||||
except (TypeError, ValueError):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 必须为整数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 必须为整数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
if segment_index < 0:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 不能为负数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "segment_index 不能为负数"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
|
||||
can_send, quota_msg = await check_ws_quota(
|
||||
quota_service, user_id, user.subscription_type
|
||||
)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": quota_msg,
|
||||
"code": "QUOTA_EXCEEDED",
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
segment_state = get_or_create_segment_state(
|
||||
@@ -323,13 +400,23 @@ async def websocket_endpoint(
|
||||
audio_duration = data.get("duration", 0)
|
||||
|
||||
if audio_base64:
|
||||
can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type)
|
||||
can_send, quota_msg = await check_ws_quota(
|
||||
quota_service, user_id, user.subscription_type
|
||||
)
|
||||
if not can_send:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": quota_msg,
|
||||
"code": "QUOTA_EXCEEDED",
|
||||
},
|
||||
"timestamp": datetime.now(
|
||||
timezone.utc
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"收到音频消息,时长: {audio_duration}s")
|
||||
@@ -337,18 +424,25 @@ async def websocket_endpoint(
|
||||
try:
|
||||
asr = get_asr_provider()
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
transcript_text = await asr.transcribe(audio_bytes, "m4a")
|
||||
transcript_text = await asr.transcribe(
|
||||
audio_bytes, "m4a"
|
||||
)
|
||||
logger.info("ASR 转写结果: %s", transcript_text)
|
||||
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": transcript_text,
|
||||
"audio_duration": audio_duration,
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {
|
||||
"text": transcript_text,
|
||||
"audio_duration": audio_duration,
|
||||
},
|
||||
"timestamp": datetime.now(
|
||||
timezone.utc
|
||||
).isoformat(),
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
)
|
||||
|
||||
segment = Segment(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -358,12 +452,18 @@ async def websocket_endpoint(
|
||||
processed=False,
|
||||
)
|
||||
db.add(segment)
|
||||
user_message_timestamp = _mark_conversation_active(conversation)
|
||||
user_message_timestamp = _mark_conversation_active(
|
||||
conversation
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(segment)
|
||||
await background_runner.queue_message(conversation.user_id, segment.id)
|
||||
await background_runner.queue_message(
|
||||
conversation.user_id, segment.id
|
||||
)
|
||||
|
||||
if transcript_text and not transcript_text.startswith("转写失败"):
|
||||
if transcript_text and not transcript_text.startswith(
|
||||
"转写失败"
|
||||
):
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=transcript_text,
|
||||
@@ -371,99 +471,141 @@ async def websocket_endpoint(
|
||||
segment=segment,
|
||||
db=db,
|
||||
user=user,
|
||||
user_message_timestamp=segment.created_at or user_message_timestamp,
|
||||
user_message_timestamp=segment.created_at
|
||||
or user_message_timestamp,
|
||||
)
|
||||
else:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "语音转写失败,请重试或使用文字输入"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": "语音转写失败,请重试或使用文字输入"
|
||||
},
|
||||
"timestamp": datetime.now(
|
||||
timezone.utc
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理音频消息失败: {e}", exc_info=True)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"处理音频消息失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": f"处理音频消息失败: {str(e)}"
|
||||
},
|
||||
"timestamp": datetime.now(
|
||||
timezone.utc
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.TRANSCRIBE_ONLY:
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
if not audio_base64:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 audio_base64"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
try:
|
||||
asr = get_asr_provider()
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
transcript_text = await asr.transcribe(audio_bytes, "m4a")
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": transcript_text or ""},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": transcript_text or ""},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"仅转写失败: {e}", exc_info=True)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"转写失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"转写失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.END_CONVERSATION:
|
||||
conversation.status = "ended"
|
||||
conversation.ended_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
await process_conversation_segments(conversation_id, db, quota_service)
|
||||
await process_conversation_segments(
|
||||
conversation_id, db, quota_service
|
||||
)
|
||||
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.END_CONVERSATION,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"status": "ended"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.END_CONVERSATION,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"status": "ended"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
break
|
||||
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e)
|
||||
if (
|
||||
"disconnect" in error_msg.lower()
|
||||
or "Cannot call \"receive\"" in error_msg
|
||||
or "accept" in error_msg.lower() and "not connected" in error_msg.lower()
|
||||
or 'Cannot call "receive"' in error_msg
|
||||
or "accept" in error_msg.lower()
|
||||
and "not connected" in error_msg.lower()
|
||||
):
|
||||
logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}")
|
||||
logger.info(
|
||||
f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(
|
||||
timezone.utc
|
||||
).isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
logger.info(
|
||||
f"WebSocket 断开连接: conversation_id={conversation_id}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}", exc_info=True)
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user