Files
life-echo/api/app/features/conversation/ws/router.py
Kevin ac49bc7f23 feat(eval): memoir A/B chapter judging and eval-web parity with dialogue
- Judge baseline excerpt and library chapter separately; build_memoir_compare_summary for gate, nine-dim and leaf deltas.

- Memoir SSE chapter payload: baseline_judge, compare_summary, baseline_judge_error.

- MemoirJudgeOutput: loose score coercion and post-validate clamp; memoir judge prompt caps from settings.

- app-eval-web: two-column MemoirScoreCard layout, MemoirCompareSummary, chapter blocks and CSS.

- Add memoir_compare_summary, log_events, celery_log_context, memoir_pipeline_progress; tests and migration 0014.

- Misc: memory/evidence and enrichment paths, task/orchestrator updates, internal-eval docs, env examples.
2026-04-10 10:25:15 +08:00

760 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WebSocket 路由:实时对话通信
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
"""
import asyncio
import base64
from datetime import datetime, timezone
from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState
from app.agents.chat.background_voice import infer_background_voice
from app.agents.chat.prompts_profile import format_user_profile_context
from app.agents.stage_constants import STAGE_TO_ORDER
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.history_store import ConversationHistoryStore
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
_delayed_listening_feedback,
_voice_session_id_from_client_segment_id,
background_runner,
bump_tts_cancel_epoch,
chat_orchestrator,
cleanup_segment_states,
get_or_create_segment_state,
process_audio_segment,
process_conversation_segments,
process_user_message,
register_segment_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.memoir.state_service import get_or_create_state
from app.features.quota.service import QuotaService
from app.features.user.models import User
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
async def websocket_endpoint(
websocket: WebSocket,
conversation_id: str,
):
"""
WebSocket endpoint: handles a real-time conversation.

Authenticates the caller from the ``token`` query parameter, registers the
socket with ``manager``, sends an opening message when the conversation has
no history, then dispatches incoming JSON messages by their ``type`` field
until the receive loop exits.

Args:
websocket: The WebSocket connection.
conversation_id: The conversation ID.
"""
# Authentication: the access token is read from the "token" query parameter.
# Every rejection below closes the socket with WS_1008_POLICY_VIOLATION.
token = websocket.query_params.get("token")
if not token:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
)
return
# Reject tokens that fail verification or whose payload "type" is not "access".
payload = verify_token(token)
if not payload or payload.get("type") != "access":
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
)
return
# The "sub" entry of the token payload is used as the user id.
user_id = payload.get("sub")
if not user_id:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
)
return
# The same `db` session is later handed to the services and to
# process_user_message / process_conversation_segments in the receive loop.
async with AsyncSessionLocal() as db:
user = await db.get(User, user_id)
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
)
return
# Register this socket with the connection manager under conversation_id.
await manager.connect(websocket, conversation_id)
logger.info(
"WebSocket 已连接 conversation_id={} user_id={}",
conversation_id,
user_id,
)
quota_service = QuotaService(db=db)
conversation_service = ConversationService(db=db, quota_service=quota_service)
try:
# Tell the client the connection is established.
await manager.send_message(
conversation_id,
{
"type": MessageType.CONNECT,
"conversation_id": conversation_id,
"data": {"status": "connected"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
# ensure_ws_connection returns a (conversation, error) pair. The two error
# values handled here, "forbidden" and "deleted", both send an ERROR message
# and then close the socket.
conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
conversation_id, user_id
)
if ws_conn_err == "forbidden":
try:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "无权访问此对话"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception:
# Best effort: a failure to deliver the error message is ignored.
pass
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话"
)
return
if ws_conn_err == "deleted":
try:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "对话已删除"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception:
# Best effort: a failure to deliver the error message is ignored.
pass
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="对话已删除"
)
return
# On cold start, align conversation_stage with MemoirState.current_stage.
# If the conversation row already holds a later life stage (larger
# STAGE_TO_ORDER value), do not overwrite it, to avoid regressing.
memoir_state = await get_or_create_state(user_id, db)
ms = (memoir_state.current_stage or "").strip()
cs = (conversation.conversation_stage or "").strip()
if ms:
if not cs:
conversation.conversation_stage = ms
# Stages missing from STAGE_TO_ORDER compare as -1.
elif STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1):
conversation.conversation_stage = ms
await db.commit()
await db.refresh(conversation)
# An empty history result triggers an AI opening message.
history = await conversation_service.ensure_redis_history_from_db(
conversation_id
)
if not history:
# When get_missing_profile_fields reports missing fields, the opening is a
# profile-collection greeting built from those fields.
missing_profile = get_missing_profile_fields(user)
if missing_profile:
try:
greetings = await chat_orchestrator.generate_profile_greeting(
conversation_id=conversation_id,
missing_fields=missing_profile,
nickname=user.nickname or "",
)
# The greetings are recorded first; they are only pushed to the client
# when record_ai_only_turn returns a truthy message id.
ai_msg_id = await ConversationHistoryStore(
db
).record_ai_only_turn(conversation_id, greetings)
if ai_msg_id:
ng = len(greetings)
for i, text in enumerate(greetings):
await manager.send_message(
conversation_id,
{
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {
"text": text,
"index": i,
"total": ng,
"assistant_message_id": ai_msg_id,
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
# Pause 0.5 s between consecutive messages, but not after the last one.
if i < ng - 1:
await asyncio.sleep(0.5)
except Exception as e:
# A failed greeting is logged and does not abort the connection setup.
logger.exception("发送资料收集开场白失败: {}", e)
# Presumably the counterpart of the `if missing_profile` check above: a
# regular opening message built from the memoir state and the user profile.
else:
try:
state = memoir_state
user_profile_context = format_user_profile_context(
birth_year=user.birth_year,
birth_place=user.birth_place,
grew_up_place=user.grew_up_place,
occupation=user.occupation,
)
# Prefer the place the user grew up in, then the birth place, else "".
era_place = (user.grew_up_place or user.birth_place or "") or ""
opening_messages = (
await chat_orchestrator.generate_opening_message(
conversation_id=conversation_id,
memoir_state=state,
user_profile_context=user_profile_context,
background_voice=infer_background_voice(
user.occupation
),
occupation=user.occupation or "",
profile_birth_year=user.birth_year,
profile_era_place=era_place,
)
)
# Same record-then-send sequence as the profile greeting above.
ai_msg_id = await ConversationHistoryStore(
db
).record_ai_only_turn(conversation_id, opening_messages)
if ai_msg_id:
no = len(opening_messages)
for i, text in enumerate(opening_messages):
await manager.send_message(
conversation_id,
{
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {
"text": text,
"index": i,
"total": no,
"assistant_message_id": ai_msg_id,
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
# Pause 0.5 s between consecutive messages, but not after the last one.
if i < no - 1:
await asyncio.sleep(0.5)
except Exception as e:
# A failed opening message is logged and does not abort the connection setup.
logger.exception("发送空对话开场白失败: {}", e)
# Main receive loop: one JSON message per iteration, dispatched on "type".
while True:
try:
# Stop once the socket is no longer in the CONNECTED application state.
if websocket.application_state != WebSocketState.CONNECTED:
logger.debug(
"WebSocket 已非连接状态,退出循环: conversation_id={}",
conversation_id,
)
break
message = await websocket.receive_json()
msg_type = message.get("type")
# Log the incoming message. For audio segments only metadata and the base64
# length are logged, never the audio payload itself.
if msg_type == MessageType.AUDIO_SEGMENT:
_d = message.get("data") or {}
# NOTE(review): int() raises on a non-numeric "duration"; that would be
# handled by the generic exception handler at the end of the loop body.
logger.info(
"WebSocket 收到消息 type={} conversation_id={} "
"segment_index={} is_last={} duration_s={} audio_b64_len={}",
msg_type,
conversation_id,
_d.get("segment_index"),
bool(_d.get("is_last")),
int(_d.get("duration") or 0),
len(_d.get("audio_base64") or ""),
)
elif msg_type is not None:
logger.info(
"WebSocket 收到消息 type={} conversation_id={}",
msg_type,
conversation_id,
)
else:
logger.warning(
"WebSocket 收到缺少 type 的 JSON conversation_id={}",
conversation_id,
)
# TEXT: quota check, persist the user segment, queue it with
# background_runner, then call process_user_message. Empty text is ignored.
if msg_type == MessageType.TEXT:
text_message = message.get("data", {}).get("text", "")
if text_message:
can_send, quota_msg = await check_ws_quota(
quota_service, user_id, user.subscription_type
)
if not can_send:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": quota_msg,
"code": "QUOTA_EXCEEDED",
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
continue
segment = await conversation_service.create_user_segment(
conversation,
user_id,
text_message,
)
user_message_timestamp = conversation.last_message_at
await background_runner.queue_message(
conversation.user_id,
segment.id,
text_char_count=len(text_message.strip()),
)
await process_user_message(
conversation_id=conversation_id,
user_message=text_message,
conversation=conversation,
segment=segment,
db=db,
user=user,
# Prefer the segment's own creation time; fall back to
# conversation.last_message_at when it is falsy.
user_message_timestamp=segment.created_at
or user_message_timestamp,
)
# RECORDING_STARTED: requires a non-blank voice_session_id, then starts a
# _delayed_listening_feedback task for that voice session.
elif msg_type == MessageType.RECORDING_STARTED:
data = message.get("data", {})
raw_vs = data.get("voice_session_id")
if not raw_vs or not str(raw_vs).strip():
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "缺少 voice_session_id"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
voice_session_id = str(raw_vs).strip()
segment_state = get_or_create_segment_state(
conversation_id,
voice_session_id,
)
async with segment_state.lock:
# Skip if a feedback task is still pending or feedback was already sent.
if (
segment_state.listening_feedback_task is not None
and not segment_state.listening_feedback_task.done()
):
continue
if segment_state.listening_feedback_sent:
continue
delayed_task = asyncio.create_task(
_delayed_listening_feedback(
conversation_id=conversation_id,
voice_session_id=voice_session_id,
)
)
segment_state.listening_feedback_task = delayed_task
# AUDIO_SEGMENT: one chunk of a segmented recording. The payload is validated,
# de-duplicated per voice session, and processed in a separate task.
elif msg_type == MessageType.AUDIO_SEGMENT:
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
segment_index_raw = data.get("segment_index")
# voice_session_id is taken from the payload, or derived from
# client_segment_id when it is absent.
resolved_vs = data.get("voice_session_id") or (
_voice_session_id_from_client_segment_id(
data.get("client_segment_id")
)
)
if not resolved_vs or not str(resolved_vs).strip():
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": "缺少 voice_session_id 或有效的 client_segment_id"
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
voice_session_id = str(resolved_vs).strip()
is_last = bool(data.get("is_last", False))
# A missing or falsy duration becomes 0.
audio_duration = int(data.get("duration", 0) or 0)
if not audio_base64:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
# segment_index must be present, convertible to int, and non-negative.
if segment_index_raw is None:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "缺少 segment_index"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
try:
segment_index = int(segment_index_raw)
except (TypeError, ValueError):
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "segment_index 必须为整数"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
if segment_index < 0:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "segment_index 不能为负数"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
can_send, quota_msg = await check_ws_quota(
quota_service, user_id, user.subscription_type
)
if not can_send:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": quota_msg,
"code": "QUOTA_EXCEEDED",
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
segment_state = get_or_create_segment_state(
conversation_id,
voice_session_id,
)
# De-duplication: an index counts as already seen when it is pending,
# processed, or not greater than consumed_index. New indices are marked
# pending while the lock is held.
should_process = False
async with segment_state.lock:
already_seen = (
segment_index in segment_state.pending_indices
or segment_index in segment_state.processed_indices
or segment_index <= segment_state.consumed_index
)
if not already_seen:
segment_state.pending_indices.add(segment_index)
should_process = True
if not should_process:
logger.debug(
"收到重复分段,跳过: conversation_id={} voice_session_id={} "
"segment_index={} audio_b64_len={} duration={}",
conversation_id,
voice_session_id,
segment_index,
len(audio_base64 or ""),
audio_duration,
)
continue
# On the last segment, detach the listening-feedback task from the state
# and cancel it if it has not finished yet.
if is_last:
async with segment_state.lock:
t = segment_state.listening_feedback_task
segment_state.listening_feedback_task = None
if t is not None and not t.done():
t.cancel()
# The segment is processed in its own task, so the receive loop does not
# wait for it; the task is tracked through register_segment_task.
task = asyncio.create_task(
process_audio_segment(
conversation_id=conversation_id,
user_id=user_id,
voice_session_id=voice_session_id,
segment_index=segment_index,
audio_base64=audio_base64,
audio_duration=audio_duration,
is_last=is_last,
)
)
register_segment_task(conversation_id, voice_session_id, task)
# AUDIO_MESSAGE: a complete audio clip in one message. Quota check, ASR
# transcription, TRANSCRIPT echo to the client, then the same persist/queue
# steps as TEXT. A payload without audio_base64 is ignored.
elif msg_type == MessageType.AUDIO_MESSAGE:
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
audio_duration = data.get("duration", 0)
if audio_base64:
can_send, quota_msg = await check_ws_quota(
quota_service, user_id, user.subscription_type
)
if not can_send:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": quota_msg,
"code": "QUOTA_EXCEEDED",
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
continue
logger.debug(
"收到音频消息: conversation_id={} duration_s={}",
conversation_id,
audio_duration,
)
try:
asr = get_asr_provider()
audio_bytes = base64.b64decode(audio_base64)
# NOTE(review): "m4a" is passed as the second argument, presumably the
# audio format; confirm against the ASR provider interface.
asr_text = await asr.transcribe(audio_bytes, "m4a")
logger.debug(
"ASR 转写完成: conversation_id={} chars={}",
conversation_id,
len(asr_text or ""),
)
# The full transcript text is written to the debug log here.
logger.debug(
"ASR 转写全文: conversation_id={} text={}",
conversation_id,
asr_text,
)
await manager.send_message(
conversation_id,
{
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {
"text": asr_text,
"audio_duration": audio_duration,
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
# Coerce the client-supplied duration to int; non-numeric values become 0,
# and only a positive value is stored as audio_duration_seconds.
try:
ads = int(audio_duration)
except (TypeError, ValueError):
ads = 0
segment = (
await conversation_service.create_user_segment(
conversation,
user_id,
asr_text,
audio_url=f"audio:{audio_duration}s",
audio_duration_seconds=ads if ads > 0 else None,
)
)
user_message_timestamp = conversation.last_message_at
await background_runner.queue_message(
conversation.user_id,
segment.id,
text_char_count=len((asr_text or "").strip()),
)
# A transcript that is empty or starts with "转写失败" is treated as an ASR
# failure and answered with an ERROR message instead of a reply.
# NOTE(review): the segment is created and queued before this check, so
# failed transcripts are persisted as well; confirm this is intended.
if asr_text and not asr_text.startswith("转写失败"):
await process_user_message(
conversation_id=conversation_id,
user_message=asr_text,
conversation=conversation,
segment=segment,
db=db,
user=user,
user_message_timestamp=segment.created_at
or user_message_timestamp,
)
else:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": "语音转写失败,请重试或使用文字输入"
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
except Exception as e:
# Any failure while handling the clip is logged and reported to the client.
logger.exception("处理音频消息失败: {}", e)
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": "语音处理失败,请重试或使用文字输入"
},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
# TRANSCRIBE_ONLY: transcribe the audio and return the text. This branch
# neither checks quota nor creates a conversation segment.
elif msg_type == MessageType.TRANSCRIBE_ONLY:
data = message.get("data", {})
audio_base64 = data.get("audio_base64", "")
if not audio_base64:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "缺少 audio_base64"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
continue
try:
asr = get_asr_provider()
audio_bytes = base64.b64decode(audio_base64)
asr_text = await asr.transcribe(audio_bytes, "m4a")
await manager.send_message(
conversation_id,
{
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
# A falsy transcript is sent as an empty string.
"data": {"text": asr_text or ""},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception as e:
logger.exception("仅转写失败: {}", e)
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "语音转写失败,请重试"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
# TTS_CANCEL: bump the TTS cancel epoch for this conversation.
elif msg_type == MessageType.TTS_CANCEL:
bump_tts_cancel_epoch(conversation_id)
# END_CONVERSATION: end the conversation, process its segments, acknowledge
# with status "ended", and leave the receive loop.
elif msg_type == MessageType.END_CONVERSATION:
await conversation_service.end(conversation_id, user_id)
await process_conversation_segments(
conversation_id, db, quota_service
)
await manager.send_message(
conversation_id,
{
"type": MessageType.END_CONVERSATION,
"conversation_id": conversation_id,
"data": {"status": "ended"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
break
# PING: answer with PONG.
elif msg_type == MessageType.PING:
await manager.send_message(
conversation_id,
{
"type": MessageType.PONG,
"conversation_id": conversation_id,
"data": {},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
# Unrecognised types are logged; a missing type was already warned about above.
else:
if msg_type is not None:
logger.warning(
"WebSocket 未识别的消息 type={} conversation_id={}",
msg_type,
conversation_id,
)
# RuntimeError: messages that look like a closed or not-yet-accepted socket
# are logged at debug level; any other RuntimeError is logged with traceback
# and reported to the client if the connection is still registered.
except RuntimeError as e:
error_msg = str(e)
# Because `and` binds tighter than `or`, this condition evaluates as
# A or B or (C and D): "accept" only matches together with "not connected".
if (
"disconnect" in error_msg.lower()
or 'Cannot call "receive"' in error_msg
or "accept" in error_msg.lower()
and "not connected" in error_msg.lower()
):
logger.debug(
"WebSocket 连接已断开或未就绪: conversation_id={} error={}",
conversation_id,
error_msg,
)
break
else:
logger.exception("处理消息时发生 RuntimeError: {}", e)
if conversation_id in manager.active_connections:
try:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "处理失败,请重试"},
"timestamp": datetime.now(
timezone.utc
).isoformat(),
},
)
except Exception as send_error:
logger.warning("发送错误消息失败: {}", send_error)
break
# The client disconnected while the loop was receiving or handling a message.
except WebSocketDisconnect as disc:
logger.info(
"WebSocket 断开连接(收消息循环): conversation_id={} code={}",
conversation_id,
getattr(disc, "code", None),
)
break
# Any other error: log with traceback and, if the connection is still
# registered with the manager, try to send a generic ERROR message.
except Exception as e:
logger.exception("处理消息时发生错误: {}", e)
if conversation_id in manager.active_connections:
try:
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "处理失败,请重试"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception as send_error:
logger.warning("发送错误消息失败: {}", send_error)
break
# Handlers for the outer try, covering the setup phase as well as anything
# that escapes the receive loop.
# NOTE(review): both handlers call disconnect/cleanup and the finally clause
# repeats the same calls, so they run twice on these paths; presumably both
# are safe to call repeatedly. Confirm, or rely on the finally clause alone.
except WebSocketDisconnect as disc:
logger.info(
"WebSocket 断开连接: conversation_id={} code={}",
conversation_id,
getattr(disc, "code", None),
)
await manager.disconnect(conversation_id)
cleanup_segment_states(conversation_id)
except Exception as e:
logger.exception("WebSocket 端点发生错误: {}", e)
await manager.disconnect(conversation_id)
cleanup_segment_states(conversation_id)
# Always unregister the connection and drop per-conversation segment state.
finally:
await manager.disconnect(conversation_id)
cleanup_segment_states(conversation_id)