Files
life-echo/api/app/features/conversation/ws/router.py
Kevin e4bf0710c7 feat(memory,conversation): 记忆富化/证据包、时间线幂等字段与对话分段全链路
数据库
- 新增迁移 0003:timeline_events.memory_source_id 外键 → memory_sources,便于按 ingest 源做时间线幂等

后端 - 记忆
- 新增 ingest 后 LLM 富化(摘要/事实/时间线),可配置开关与最大字符数
- 新增证据包组装:合并 chunk、摘要、事实、时间线、故事等检索结果;支持空 query 时是否仍带 rolling 等开关
- repo/retriever/service/router/schemas/summarizer/timeline/extractor 等扩展;文档 memory-retrieval.md 更新

后端 - 对话 WS
- 增加 PING/PONG;分段 ASR 日志与空音频处理;转写失败与「无助手回复」错误提示更明确
- 助手多段回复持久化使用统一分隔符,与分段逻辑一致

后端 - Agent
- reply_limits:按 [SPLIT] 与段落拆段,并保证非空 fallback,供 WS 与 TTS 多段下发

后端 - 回忆录任务
- transcript ingest 记录 source_id;任务成功结…
2026-03-27 16:24:43 +08:00

757 lines
35 KiB
Python

"""
WebSocket 路由:实时对话通信
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
"""
import asyncio
import base64
import uuid
from datetime import datetime, timezone
from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState
from app.agents.chat.prompts_profile import format_user_profile_context
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.history_store import ConversationHistoryStore
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
_delayed_listening_feedback,
_mark_conversation_active,
_voice_session_id_from_client_segment_id,
background_runner,
bump_tts_cancel_epoch,
chat_orchestrator,
cleanup_segment_states,
get_or_create_segment_state,
process_audio_segment,
process_conversation_segments,
process_user_message,
register_segment_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.memoir.state_service import get_or_create_state
from app.features.quota.service import QuotaService
from app.features.user.models import User
logger = get_logger(__name__)
def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string for frame timestamps."""
    return datetime.now(timezone.utc).isoformat()


def _message_data(message) -> dict:
    """Return the ``data`` object of a client frame, or ``{}`` if absent/not an object.

    Clients may send ``"data": null`` (or omit it); treating that as an empty
    payload lets the per-type validation answer with a proper ERROR frame
    instead of failing on ``None.get``.
    """
    data = message.get("data") if isinstance(message, dict) else None
    return data if isinstance(data, dict) else {}


def _coerce_int(value, default: int = 0) -> int:
    """Convert a client-supplied value to ``int``; return ``default`` if not numeric."""
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        return default


def _is_disconnect_runtime_error(error_msg: str) -> bool:
    """Tell whether a RuntimeError message describes a closed / not-yet-accepted socket."""
    lowered = error_msg.lower()
    return (
        "disconnect" in lowered
        or 'Cannot call "receive"' in error_msg
        or ("accept" in lowered and "not connected" in lowered)
    )


async def _send_error(conversation_id: str, message: str, code=None) -> None:
    """Send an ERROR frame; ``code`` is added to ``data`` only when given."""
    data = {"message": message}
    if code is not None:
        data["code"] = code
    await manager.send_message(
        conversation_id,
        {
            "type": MessageType.ERROR,
            "data": data,
            "timestamp": _utc_now_iso(),
        },
    )


async def _reject_connection(
    websocket: WebSocket, conversation_id: str, reason: str
) -> None:
    """Tell the client why it is rejected (best effort), then close with 1008."""
    try:
        await _send_error(conversation_id, reason)
    except Exception as send_error:
        # Best effort only: the close below must still happen.
        logger.debug(
            "发送拒绝原因失败: conversation_id={} error={}",
            conversation_id,
            send_error,
        )
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=reason)


async def _quota_allows(
    quota_service: QuotaService, user: User, user_id, conversation_id: str
) -> bool:
    """Check the user's quota; on refusal send a QUOTA_EXCEEDED error and return False."""
    can_send, quota_msg = await check_ws_quota(
        quota_service, user_id, user.subscription_type
    )
    if not can_send:
        await _send_error(conversation_id, quota_msg, code="QUOTA_EXCEEDED")
        return False
    return True


async def _send_agent_parts(
    conversation_id: str, parts, assistant_message_id
) -> None:
    """Send each part as an AGENT_RESPONSE frame, pausing 0.5 s between parts."""
    total = len(parts)
    for index, text in enumerate(parts):
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.AGENT_RESPONSE,
                "conversation_id": conversation_id,
                "data": {
                    "text": text,
                    "index": index,
                    "total": total,
                    "assistant_message_id": assistant_message_id,
                },
                "timestamp": _utc_now_iso(),
            },
        )
        if index < total - 1:
            await asyncio.sleep(0.5)


async def _send_opening_turn(conversation_id: str, user: User, user_id, db) -> None:
    """Open an empty conversation with an assistant turn.

    If the user's profile has missing fields a profile-collection greeting is
    generated; otherwise a regular opening message is generated from the
    memoir state and the known profile. The turn is recorded through
    ``ConversationHistoryStore`` and only sent when a message id comes back.
    Failures are logged and never abort the connection.
    """
    missing_profile = get_missing_profile_fields(user)
    if missing_profile:
        try:
            greetings = await chat_orchestrator.generate_profile_greeting(
                conversation_id=conversation_id,
                missing_fields=missing_profile,
                nickname=user.nickname or "",
            )
            ai_msg_id = await ConversationHistoryStore(db).record_ai_only_turn(
                conversation_id, greetings
            )
            if ai_msg_id:
                await _send_agent_parts(conversation_id, greetings, ai_msg_id)
        except Exception as e:
            logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
        return

    try:
        state = await get_or_create_state(user_id, db)
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )
        opening_messages = await chat_orchestrator.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=state,
            user_profile_context=user_profile_context,
        )
        ai_msg_id = await ConversationHistoryStore(db).record_ai_only_turn(
            conversation_id, opening_messages
        )
        if ai_msg_id:
            await _send_agent_parts(conversation_id, opening_messages, ai_msg_id)
    except Exception as e:
        logger.error(f"发送空对话开场白失败: {e}", exc_info=True)


def _log_incoming(message, msg_type, conversation_id: str) -> None:
    """Log one received frame; audio segments get their metadata logged, not the audio."""
    if msg_type == MessageType.AUDIO_SEGMENT:
        data = _message_data(message)
        logger.info(
            "WebSocket 收到消息 type={} conversation_id={} "
            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
            msg_type,
            conversation_id,
            data.get("segment_index"),
            bool(data.get("is_last")),
            _coerce_int(data.get("duration")),
            len(data.get("audio_base64") or ""),
        )
    elif msg_type is not None:
        logger.info(
            "WebSocket 收到消息 type={} conversation_id={}",
            msg_type,
            conversation_id,
        )
    else:
        logger.warning(
            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
            conversation_id,
        )


async def _handle_text_message(
    message,
    *,
    conversation_id: str,
    conversation: Conversation,
    user: User,
    user_id,
    db,
    quota_service: QuotaService,
) -> None:
    """Persist a TEXT message as a Segment and run it through the chat pipeline."""
    text_message = _message_data(message).get("text", "")
    if not text_message:
        return
    if not await _quota_allows(quota_service, user, user_id, conversation_id):
        return

    segment = Segment(
        id=str(uuid.uuid4()),
        conversation_id=conversation_id,
        user_input_text=text_message,
        processed=False,
    )
    db.add(segment)
    user_message_timestamp = _mark_conversation_active(conversation)
    await db.commit()
    await db.refresh(segment)
    await background_runner.queue_message(conversation.user_id, segment.id)
    await process_user_message(
        conversation_id=conversation_id,
        user_message=text_message,
        conversation=conversation,
        segment=segment,
        db=db,
        user=user,
        user_message_timestamp=segment.created_at or user_message_timestamp,
    )


async def _handle_recording_started(message, conversation_id: str) -> None:
    """Schedule the delayed "listening" feedback once per voice session."""
    raw_vs = _message_data(message).get("voice_session_id")
    if not raw_vs or not str(raw_vs).strip():
        await _send_error(conversation_id, "缺少 voice_session_id")
        return

    voice_session_id = str(raw_vs).strip()
    segment_state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with segment_state.lock:
        pending = segment_state.listening_feedback_task
        if pending is not None and not pending.done():
            return  # a feedback task is already scheduled
        if segment_state.listening_feedback_sent:
            return  # feedback was already delivered for this session
        segment_state.listening_feedback_task = asyncio.create_task(
            _delayed_listening_feedback(
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
            )
        )


async def _handle_audio_segment(
    message,
    *,
    conversation_id: str,
    user: User,
    user_id,
    quota_service: QuotaService,
) -> None:
    """Validate one AUDIO_SEGMENT frame, drop duplicates and process it in a task."""
    data = _message_data(message)
    audio_base64 = data.get("audio_base64", "")
    segment_index_raw = data.get("segment_index")
    resolved_vs = data.get(
        "voice_session_id"
    ) or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
    if not resolved_vs or not str(resolved_vs).strip():
        await _send_error(
            conversation_id, "缺少 voice_session_id 或有效的 client_segment_id"
        )
        return

    voice_session_id = str(resolved_vs).strip()
    is_last = bool(data.get("is_last", False))
    # A malformed duration must not drop the connection; it falls back to 0.
    audio_duration = _coerce_int(data.get("duration", 0))

    if not audio_base64:
        await _send_error(conversation_id, "缺少 audio_base64")
        return
    if segment_index_raw is None:
        await _send_error(conversation_id, "缺少 segment_index")
        return
    try:
        segment_index = int(segment_index_raw)
    except (TypeError, ValueError):
        await _send_error(conversation_id, "segment_index 必须为整数")
        return
    if segment_index < 0:
        await _send_error(conversation_id, "segment_index 不能为负数")
        return

    if not await _quota_allows(quota_service, user, user_id, conversation_id):
        return

    segment_state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with segment_state.lock:
        already_seen = (
            segment_index in segment_state.pending_indices
            or segment_index in segment_state.processed_indices
            or segment_index <= segment_state.consumed_index
        )
        if not already_seen:
            segment_state.pending_indices.add(segment_index)
    if already_seen:
        logger.debug(
            "收到重复分段,跳过: conversation_id={} voice_session_id={} "
            "segment_index={} audio_b64_len={} duration={}",
            conversation_id,
            voice_session_id,
            segment_index,
            len(audio_base64 or ""),
            audio_duration,
        )
        return

    if is_last:
        # Recording is over: a still-pending "listening" feedback is cancelled.
        async with segment_state.lock:
            feedback_task = segment_state.listening_feedback_task
            segment_state.listening_feedback_task = None
        if feedback_task is not None and not feedback_task.done():
            feedback_task.cancel()

    task = asyncio.create_task(
        process_audio_segment(
            conversation_id=conversation_id,
            user_id=user_id,
            voice_session_id=voice_session_id,
            segment_index=segment_index,
            audio_base64=audio_base64,
            audio_duration=audio_duration,
            is_last=is_last,
        )
    )
    register_segment_task(conversation_id, voice_session_id, task)


async def _handle_audio_message(
    message,
    *,
    conversation_id: str,
    conversation: Conversation,
    user: User,
    user_id,
    db,
    quota_service: QuotaService,
) -> None:
    """Transcribe a whole AUDIO_MESSAGE, persist it and feed the text to the pipeline."""
    data = _message_data(message)
    audio_base64 = data.get("audio_base64", "")
    audio_duration = data.get("duration", 0)
    if not audio_base64:
        return
    if not await _quota_allows(quota_service, user, user_id, conversation_id):
        return

    logger.debug(
        "收到音频消息: conversation_id={} duration_s={}",
        conversation_id,
        audio_duration,
    )
    try:
        asr = get_asr_provider()
        audio_bytes = base64.b64decode(audio_base64)
        asr_text = await asr.transcribe(audio_bytes, "m4a")
        logger.debug(
            "ASR 转写完成: conversation_id={} chars={}",
            conversation_id,
            len(asr_text or ""),
        )
        logger.debug(
            "ASR 转写全文: conversation_id={} text={}",
            conversation_id,
            asr_text,
        )
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {
                    "text": asr_text,
                    "audio_duration": audio_duration,
                },
                "timestamp": _utc_now_iso(),
            },
        )

        duration_seconds = _coerce_int(audio_duration)
        segment = Segment(
            id=str(uuid.uuid4()),
            conversation_id=conversation_id,
            user_input_text=asr_text,
            audio_url=f"audio:{audio_duration}s",
            audio_duration_seconds=duration_seconds if duration_seconds > 0 else None,
            processed=False,
        )
        db.add(segment)
        user_message_timestamp = _mark_conversation_active(conversation)
        await db.commit()
        await db.refresh(segment)
        await background_runner.queue_message(conversation.user_id, segment.id)

        if asr_text and not asr_text.startswith("转写失败"):
            await process_user_message(
                conversation_id=conversation_id,
                user_message=asr_text,
                conversation=conversation,
                segment=segment,
                db=db,
                user=user,
                user_message_timestamp=segment.created_at or user_message_timestamp,
            )
        else:
            await _send_error(conversation_id, "语音转写失败,请重试或使用文字输入")
    except Exception as e:
        logger.error(f"处理音频消息失败: {e}", exc_info=True)
        await _send_error(conversation_id, f"处理音频消息失败: {str(e)}")


async def _handle_transcribe_only(message, conversation_id: str) -> None:
    """Transcribe the audio of a TRANSCRIBE_ONLY frame and return the text only."""
    audio_base64 = _message_data(message).get("audio_base64", "")
    if not audio_base64:
        await _send_error(conversation_id, "缺少 audio_base64")
        return
    try:
        asr = get_asr_provider()
        audio_bytes = base64.b64decode(audio_base64)
        asr_text = await asr.transcribe(audio_bytes, "m4a")
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {"text": asr_text or ""},
                "timestamp": _utc_now_iso(),
            },
        )
    except Exception as e:
        logger.error(f"仅转写失败: {e}", exc_info=True)
        await _send_error(conversation_id, f"转写失败: {str(e)}")


async def _report_loop_error(conversation_id: str, error: Exception) -> None:
    """Forward an unexpected loop error to the client if it is still registered."""
    if conversation_id not in manager.active_connections:
        return
    try:
        await _send_error(conversation_id, str(error))
    except Exception as send_error:
        logger.warning(f"发送错误消息失败: {send_error}")


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """
    WebSocket endpoint for a real-time conversation.

    Lifecycle:
      1. Authenticate with the ``token`` query parameter (must be an access
         token carrying ``sub``) and load the user; on failure the socket is
         closed with code 1008.
      2. Register the socket with the connection manager and send CONNECT.
      3. Load the conversation, creating it if it does not exist; reject it
         when it belongs to another user or is deleted.
      4. If the conversation has no history, send an opening assistant turn.
      5. Dispatch incoming JSON frames by ``type`` until the client
         disconnects, END_CONVERSATION is received or an unexpected error
         occurs.
      6. Always unregister the socket and drop per-conversation segment state.

    Args:
        websocket: the WebSocket connection.
        conversation_id: id of the conversation this socket is bound to.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return

    # One DB session is held for the whole lifetime of the connection.
    async with AsyncSessionLocal() as db:
        user = await db.get(User, user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        # NOTE(review): the socket is registered under conversation_id before
        # ownership of the conversation is verified below. If the manager
        # keeps a single socket per conversation, a non-owner could displace
        # the owner's socket, and the cleanup in `finally` would also drop the
        # owner's state - confirm against ConnectionManager.
        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )
        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _utc_now_iso(),
                },
            )

            conversation = await db.get(Conversation, conversation_id)
            if not conversation:
                conversation = Conversation(
                    id=conversation_id,
                    user_id=user_id,
                    started_at=datetime.now(timezone.utc),
                    status="active",
                )
                db.add(conversation)
                await db.commit()
            elif conversation.user_id != user_id:
                await _reject_connection(websocket, conversation_id, "无权访问此对话")
                return
            elif conversation.deleted_at is not None:
                await _reject_connection(websocket, conversation_id, "对话已删除")
                return

            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )
            if not history:
                await _send_opening_turn(conversation_id, user, user_id, db)

            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break

                    message = await websocket.receive_json()
                    # A non-object JSON frame is handled like a frame without type.
                    msg_type = (
                        message.get("type") if isinstance(message, dict) else None
                    )
                    _log_incoming(message, msg_type, conversation_id)

                    if msg_type == MessageType.TEXT:
                        await _handle_text_message(
                            message,
                            conversation_id=conversation_id,
                            conversation=conversation,
                            user=user,
                            user_id=user_id,
                            db=db,
                            quota_service=quota_service,
                        )
                    elif msg_type == MessageType.RECORDING_STARTED:
                        await _handle_recording_started(message, conversation_id)
                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        await _handle_audio_segment(
                            message,
                            conversation_id=conversation_id,
                            user=user,
                            user_id=user_id,
                            quota_service=quota_service,
                        )
                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        await _handle_audio_message(
                            message,
                            conversation_id=conversation_id,
                            conversation=conversation,
                            user=user,
                            user_id=user_id,
                            db=db,
                            quota_service=quota_service,
                        )
                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        await _handle_transcribe_only(message, conversation_id)
                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)
                    elif msg_type == MessageType.END_CONVERSATION:
                        conversation.status = "ended"
                        conversation.ended_at = datetime.now(timezone.utc)
                        await db.commit()
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": _utc_now_iso(),
                            },
                        )
                        break
                    elif msg_type == MessageType.PING:
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": _utc_now_iso(),
                            },
                        )
                    elif msg_type is not None:
                        logger.warning(
                            "WebSocket 未识别的消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                except RuntimeError as e:
                    error_msg = str(e)
                    if _is_disconnect_runtime_error(error_msg):
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                        break
                    logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
                    await _report_loop_error(conversation_id, e)
                    # The loop ends after any unexpected error so that a
                    # possibly failed DB session is not reused.
                    break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    logger.error(f"处理消息时发生错误: {e}", exc_info=True)
                    await _report_loop_error(conversation_id, e)
                    break
        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
        except Exception as e:
            logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
        finally:
            # Cleanup runs exactly once, here; segment state is dropped even
            # if unregistering the socket fails.
            try:
                await manager.disconnect(conversation_id)
            finally:
                cleanup_segment_states(conversation_id)