Files
life-echo/api/app/features/conversation/ws/router.py
Claude 55cfbc7f80 feat: agent proactively re-engages users on returning sessions
Two complementary changes to reduce conversation cold-start friction:

A. Returning-user re-greeting (backend)
- When WS reconnects to a non-empty conversation and last_message_at is older
  than chat_re_greeting_idle_hours (default 6h), the agent emits a warm
  continuation message that references prior history instead of staying silent.
- Self-debouncing: the AI message updates last_message_at, so reconnects
  within the window will not re-trigger.
- Skipped while profile collection is still pending.

D. Topic suggestion chips (backend + Expo)
- New WS message type topic_suggestions carries 3-4 quick-start chips derived
  from the current memoir stage's empty slots (deterministic, no extra LLM
  cost). Sent alongside opening / re-greeting / resume.
- Expo chat screen renders a horizontally-scrollable chip row above the input
  bar; tapping a chip sends the chip's text as a user message and clears the
  row. Sending any text/voice also clears the chips.
2026-05-07 15:39:33 +00:00

865 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WebSocket 路由:实时对话通信
仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块
"""
import asyncio
import base64
from datetime import datetime, timezone
from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState
from app.agents.chat.background_voice import infer_background_voice
from app.agents.chat.prompts_conversation import build_topic_chips
from app.agents.chat.prompts_profile import format_user_profile_context
from app.agents.stage_constants import STAGE_TO_ORDER
from app.agents.state_schema import (
interview_control_state,
narrative_coverage_state,
)
from app.core.config import settings
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.history_store import ConversationHistoryStore
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
_delayed_listening_feedback,
_voice_session_id_from_client_segment_id,
bump_tts_cancel_epoch,
chat_orchestrator,
cleanup_segment_states,
get_or_create_segment_state,
memoir_ingest_scheduler,
process_audio_segment,
process_conversation_segments,
process_user_message,
register_segment_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.memoir.state_service import get_or_create_state
from app.features.quota.service import QuotaService
from app.features.user.models import User
logger = get_logger(__name__)
def _idle_hours_since(ts) -> float | None:
"""计算距 ts 的小时数ts 为 None 或非 datetime 时返回 None。"""
if ts is None:
return None
if not isinstance(ts, datetime):
return None
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
delta = datetime.now(timezone.utc) - ts
return max(0.0, delta.total_seconds() / 3600.0)
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """
    WebSocket endpoint: drive one real-time conversation over a socket.

    Flow, as written below:
      1. Authenticate via the ``token`` query parameter and load the user.
      2. Call ``manager.connect`` and send a CONNECT frame.
      3. Validate access to the conversation, align its stage with the
         memoir state, and load the history.
      4. Emit a profile greeting / opening message (empty history), a
         re-greeting (history present and idle for at least the configured
         number of hours), or only topic chips.
      5. Loop over incoming JSON frames and dispatch on ``type``.
      6. Disconnect from ``manager`` and clean up segment state on exit.

    NOTE(review): the copy of this file under review had lost its
    indentation; the nesting below was reconstructed from syntax. Places
    where more than one nesting would parse are flagged inline and should be
    confirmed against the repository.

    Args:
        websocket: the WebSocket connection
        conversation_id: conversation ID
    """
    # --- Authentication: the access token travels in the query string. ---
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return
    payload = verify_token(token)
    # Only payloads typed "access" are accepted.
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return
    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return
    # The same DB session is used by everything that follows.
    async with AsyncSessionLocal() as db:
        user = await db.get(User, user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return
        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )
        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)
        try:
            # Acknowledge the connection before doing any conversation checks.
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
                conversation_id, user_id
            )
            if ws_conn_err == "forbidden":
                # Best effort: a failure to deliver the error frame is ignored
                # so that the socket is still closed below.
                try:
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.ERROR,
                            "data": {"message": "无权访问此对话"},
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                        },
                    )
                except Exception:
                    pass
                await websocket.close(
                    code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话"
                )
                return
            if ws_conn_err == "deleted":
                # Same best-effort pattern as the "forbidden" case above.
                try:
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.ERROR,
                            "data": {"message": "对话已删除"},
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                        },
                    )
                except Exception:
                    pass
                await websocket.close(
                    code=status.WS_1008_POLICY_VIOLATION, reason="对话已删除"
                )
                return
            # Cold start: align conversation_stage with MemoirState.current_stage.
            # If the conversation row already holds a later life stage (larger
            # STAGE_TO_ORDER value) it is left untouched so the stage never
            # moves backwards.
            memoir_state = await get_or_create_state(user_id, db)
            ms = (memoir_state.current_stage or "").strip()
            cs = (conversation.conversation_stage or "").strip()
            if ms:
                if not cs:
                    conversation.conversation_stage = ms
                elif STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1):
                    conversation.conversation_stage = ms
            # NOTE(review): indentation was lost in this view — the commit and
            # refresh are assumed to run unconditionally; they may instead be
            # nested under `if ms:`. Confirm.
            await db.commit()
            await db.refresh(conversation)
            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )

            async def _stream_ai_only_messages(
                texts: list[str], log_label: str
            ) -> None:
                """Persist a batch of AI messages as one AI-only turn, then
                send each text as its own AGENT_RESPONSE frame.

                (The original note describes these texts as [SPLIT] segments.)
                Does nothing when ``texts`` is empty or when recording the
                turn returns no message id.
                """
                if not texts:
                    return
                ai_msg_id = await ConversationHistoryStore(db).record_ai_only_turn(
                    conversation_id, texts
                )
                if not ai_msg_id:
                    return
                total_n = len(texts)
                for i, text in enumerate(texts):
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.AGENT_RESPONSE,
                            "conversation_id": conversation_id,
                            "data": {
                                "text": text,
                                "index": i,
                                "total": total_n,
                                "assistant_message_id": ai_msg_id,
                            },
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                        },
                    )
                    # Pause half a second between segments, but not after the
                    # last one.
                    if i < total_n - 1:
                        await asyncio.sleep(0.5)
                logger.info(
                    "event=ws_auto_ai_sent label={} conversation_id={} segments={}",
                    log_label,
                    conversation_id,
                    total_n,
                )

            async def _maybe_send_topic_chips(reason: str) -> None:
                """Build quick-start topic chips from the current stage's
                empty slots and send them as a TOPIC_SUGGESTIONS frame.

                Any failure is logged as a warning and swallowed. Nothing is
                sent when the feature flag is off, when profile fields are
                still missing, or when no chips are produced.
                """
                if not settings.chat_topic_chips_enabled:
                    return
                # Skip chips while the profile is incomplete: profile
                # collection runs through a separate flow and chips would
                # only add noise there.
                if get_missing_profile_fields(user):
                    return
                try:
                    narrative_state = narrative_coverage_state(memoir_state)
                    control_state = interview_control_state(memoir_state)
                    empty_slots = control_state.prompt_empty_slots_for_stage(
                        narrative_state, memoir_state.current_stage
                    )
                    chips = build_topic_chips(
                        memoir_state.current_stage,
                        empty_slots,
                        max_chips=settings.chat_topic_chips_max,
                    )
                    if not chips:
                        return
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.TOPIC_SUGGESTIONS,
                            "conversation_id": conversation_id,
                            "data": {
                                "reason": reason,
                                "stage": memoir_state.current_stage,
                                "suggestions": chips,
                            },
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                        },
                    )
                    logger.info(
                        "event=ws_topic_chips_sent reason={} conversation_id={} "
                        "stage={} count={}",
                        reason,
                        conversation_id,
                        memoir_state.current_stage,
                        len(chips),
                    )
                except Exception as e:
                    logger.warning("发送话题 chips 失败: {}", e)

            if not history:
                # Empty conversation: either start profile collection or send
                # the regular opening message.
                missing_profile = get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await chat_orchestrator.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                        )
                        await _stream_ai_only_messages(
                            greetings, log_label="profile_greeting"
                        )
                    except Exception as e:
                        logger.exception("发送资料收集开场白失败: {}", e)
                else:
                    try:
                        user_profile_context = format_user_profile_context(
                            birth_year=user.birth_year,
                            birth_place=user.birth_place,
                            grew_up_place=user.grew_up_place,
                            occupation=user.occupation,
                        )
                        # Prefer where the user grew up, then the birth place.
                        era_place = (user.grew_up_place or user.birth_place or "") or ""
                        opening_messages = (
                            await chat_orchestrator.generate_opening_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                user_profile_context=user_profile_context,
                                background_voice=infer_background_voice(
                                    user.occupation
                                ),
                                occupation=user.occupation or "",
                                profile_birth_year=user.birth_year,
                                profile_era_place=era_place,
                            )
                        )
                        await _stream_ai_only_messages(
                            opening_messages, log_label="opening"
                        )
                        await _maybe_send_topic_chips(reason="opening")
                    except Exception as e:
                        logger.exception("发送空对话开场白失败: {}", e)
            else:
                # History is non-empty: decide whether a returning-user
                # greeting is due (time since the last message has reached
                # the configured threshold).
                idle_hours = _idle_hours_since(conversation.last_message_at)
                threshold = float(settings.chat_re_greeting_idle_hours)
                if (
                    settings.chat_re_greeting_enabled
                    and not get_missing_profile_fields(user)
                    and idle_hours is not None
                    and idle_hours >= threshold
                ):
                    try:
                        user_profile_context = format_user_profile_context(
                            birth_year=user.birth_year,
                            birth_place=user.birth_place,
                            grew_up_place=user.grew_up_place,
                            occupation=user.occupation,
                        )
                        era_place = (user.grew_up_place or user.birth_place or "") or ""
                        re_greetings = (
                            await chat_orchestrator.generate_re_greeting_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                idle_hours=idle_hours,
                                user_profile_context=user_profile_context,
                                background_voice=infer_background_voice(
                                    user.occupation
                                ),
                                occupation=user.occupation or "",
                                profile_birth_year=user.birth_year,
                                profile_era_place=era_place,
                            )
                        )
                        await _stream_ai_only_messages(
                            re_greetings, log_label="re_greeting"
                        )
                        logger.info(
                            "event=ws_re_greeting_emitted conversation_id={} "
                            "idle_hours={:.2f} threshold={:.2f}",
                            conversation_id,
                            idle_hours,
                            threshold,
                        )
                        await _maybe_send_topic_chips(reason="re_greeting")
                    except Exception as e:
                        logger.exception("发送回访问候失败: {}", e)
                else:
                    # Even when no re-greeting fires, chips may still be sent
                    # to lower the cold-start barrier.
                    await _maybe_send_topic_chips(reason="resume")

            # --- Receive loop: one JSON frame per iteration. ---
            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break
                    message = await websocket.receive_json()
                    msg_type = message.get("type")
                    # Log the frame; audio segments log sizes only, never the
                    # audio payload itself.
                    if msg_type == MessageType.AUDIO_SEGMENT:
                        _d = message.get("data") or {}
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={} "
                            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
                            msg_type,
                            conversation_id,
                            _d.get("segment_index"),
                            bool(_d.get("is_last")),
                            int(_d.get("duration") or 0),
                            len(_d.get("audio_base64") or ""),
                        )
                    elif msg_type is not None:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                    else:
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )
                    if msg_type == MessageType.TEXT:
                        # Text message: quota check, persist the user segment,
                        # queue it for memoir ingest, then run the reply
                        # pipeline. Empty text is silently ignored.
                        text_message = message.get("data", {}).get("text", "")
                        if text_message:
                            can_send, quota_msg = await check_ws_quota(
                                quota_service, user_id, user.subscription_type
                            )
                            if not can_send:
                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.ERROR,
                                        "data": {
                                            "message": quota_msg,
                                            "code": "QUOTA_EXCEEDED",
                                        },
                                        "timestamp": datetime.now(
                                            timezone.utc
                                        ).isoformat(),
                                    },
                                )
                                continue
                            segment = await conversation_service.create_user_segment(
                                conversation,
                                user_id,
                                text_message,
                            )
                            user_message_timestamp = conversation.last_message_at
                            await memoir_ingest_scheduler.queue_segment(
                                conversation.user_id,
                                segment.id,
                                text_char_count=len(text_message.strip()),
                            )
                            # The segment's own created_at wins; the
                            # conversation's last_message_at is the fallback.
                            await process_user_message(
                                conversation_id=conversation_id,
                                user_message=text_message,
                                conversation=conversation,
                                segment=segment,
                                db=db,
                                user=user,
                                user_message_timestamp=segment.created_at
                                or user_message_timestamp,
                            )
                    elif msg_type == MessageType.RECORDING_STARTED:
                        # Recording started: schedule the delayed "listening"
                        # feedback once per voice session.
                        data = message.get("data", {})
                        raw_vs = data.get("voice_session_id")
                        if not raw_vs or not str(raw_vs).strip():
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "缺少 voice_session_id"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        voice_session_id = str(raw_vs).strip()
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            # Skip if a feedback task is still running or the
                            # feedback was already sent for this session.
                            if (
                                segment_state.listening_feedback_task is not None
                                and not segment_state.listening_feedback_task.done()
                            ):
                                continue
                            if segment_state.listening_feedback_sent:
                                continue
                            delayed_task = asyncio.create_task(
                                _delayed_listening_feedback(
                                    conversation_id=conversation_id,
                                    voice_session_id=voice_session_id,
                                )
                            )
                            segment_state.listening_feedback_task = delayed_task
                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        # Streaming audio segment: validate the fields, check
                        # quota, de-duplicate by index, then process the
                        # segment in a background task.
                        data = message.get("data", {})
                        audio_base64 = data.get("audio_base64", "")
                        segment_index_raw = data.get("segment_index")
                        # Fall back to a session id derived from
                        # client_segment_id when none is given explicitly.
                        resolved_vs = data.get("voice_session_id") or (
                            _voice_session_id_from_client_segment_id(
                                data.get("client_segment_id")
                            )
                        )
                        if not resolved_vs or not str(resolved_vs).strip():
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {
                                        "message": "缺少 voice_session_id 或有效的 client_segment_id"
                                    },
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        voice_session_id = str(resolved_vs).strip()
                        is_last = bool(data.get("is_last", False))
                        audio_duration = int(data.get("duration", 0) or 0)
                        if not audio_base64:
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "缺少 audio_base64"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        if segment_index_raw is None:
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "缺少 segment_index"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError):
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "segment_index 必须为整数"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        if segment_index < 0:
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "segment_index 不能为负数"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        can_send, quota_msg = await check_ws_quota(
                            quota_service, user_id, user.subscription_type
                        )
                        if not can_send:
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {
                                        "message": quota_msg,
                                        "code": "QUOTA_EXCEEDED",
                                    },
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        should_process = False
                        async with segment_state.lock:
                            # An index counts as already seen when it is
                            # pending, processed, or not above consumed_index.
                            already_seen = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not already_seen:
                                segment_state.pending_indices.add(segment_index)
                                should_process = True
                        if not should_process:
                            logger.debug(
                                "收到重复分段,跳过: conversation_id={} voice_session_id={} "
                                "segment_index={} audio_b64_len={} duration={}",
                                conversation_id,
                                voice_session_id,
                                segment_index,
                                len(audio_base64 or ""),
                                audio_duration,
                            )
                            continue
                        if is_last:
                            # Final segment: detach and cancel any still-running
                            # listening-feedback task.
                            # NOTE(review): indentation was lost in this view —
                            # the cancel is assumed to happen while holding the
                            # lock; it may sit just after it. Confirm.
                            async with segment_state.lock:
                                t = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                                if t is not None and not t.done():
                                    t.cancel()
                        # Processing runs as a task so the receive loop keeps
                        # reading frames meanwhile.
                        task = asyncio.create_task(
                            process_audio_segment(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                            )
                        )
                        register_segment_task(conversation_id, voice_session_id, task)
                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        # Whole-clip audio message: transcribe inline, echo the
                        # transcript, persist the segment, then run the reply
                        # pipeline. A missing audio payload is silently ignored.
                        data = message.get("data", {})
                        audio_base64 = data.get("audio_base64", "")
                        audio_duration = data.get("duration", 0)
                        if audio_base64:
                            can_send, quota_msg = await check_ws_quota(
                                quota_service, user_id, user.subscription_type
                            )
                            if not can_send:
                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.ERROR,
                                        "data": {
                                            "message": quota_msg,
                                            "code": "QUOTA_EXCEEDED",
                                        },
                                        "timestamp": datetime.now(
                                            timezone.utc
                                        ).isoformat(),
                                    },
                                )
                                continue
                            logger.debug(
                                "收到音频消息: conversation_id={} duration_s={}",
                                conversation_id,
                                audio_duration,
                            )
                            try:
                                asr = get_asr_provider()
                                audio_bytes = base64.b64decode(audio_base64)
                                # The container format is hard-coded as "m4a".
                                asr_text = await asr.transcribe(audio_bytes, "m4a")
                                logger.debug(
                                    "ASR 转写完成: conversation_id={} chars={}",
                                    conversation_id,
                                    len(asr_text or ""),
                                )
                                logger.debug(
                                    "ASR 转写全文: conversation_id={} text={}",
                                    conversation_id,
                                    asr_text,
                                )
                                # The transcript is sent to the client before
                                # the segment is stored.
                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.TRANSCRIPT,
                                        "conversation_id": conversation_id,
                                        "data": {
                                            "text": asr_text,
                                            "audio_duration": audio_duration,
                                        },
                                        "timestamp": datetime.now(
                                            timezone.utc
                                        ).isoformat(),
                                    },
                                )
                                # Client-supplied duration may be non-numeric;
                                # fall back to 0 (stored as None below).
                                try:
                                    ads = int(audio_duration)
                                except (TypeError, ValueError):
                                    ads = 0
                                # NOTE(review): the segment is stored and queued
                                # for memoir ingest before the failure check
                                # below, so failed transcripts are persisted
                                # too — confirm this is intended.
                                segment = (
                                    await conversation_service.create_user_segment(
                                        conversation,
                                        user_id,
                                        asr_text,
                                        audio_url=f"audio:{audio_duration}s",
                                        audio_duration_seconds=ads if ads > 0 else None,
                                    )
                                )
                                user_message_timestamp = conversation.last_message_at
                                await memoir_ingest_scheduler.queue_segment(
                                    conversation.user_id,
                                    segment.id,
                                    text_char_count=len((asr_text or "").strip()),
                                )
                                # An empty transcript, or one starting with the
                                # "转写失败" marker, is treated as an ASR failure.
                                if asr_text and not asr_text.startswith("转写失败"):
                                    await process_user_message(
                                        conversation_id=conversation_id,
                                        user_message=asr_text,
                                        conversation=conversation,
                                        segment=segment,
                                        db=db,
                                        user=user,
                                        user_message_timestamp=segment.created_at
                                        or user_message_timestamp,
                                    )
                                else:
                                    await manager.send_message(
                                        conversation_id,
                                        {
                                            "type": MessageType.ERROR,
                                            "data": {
                                                "message": "语音转写失败,请重试或使用文字输入"
                                            },
                                            "timestamp": datetime.now(
                                                timezone.utc
                                            ).isoformat(),
                                        },
                                    )
                            except Exception as e:
                                # Audio failures are reported to the client
                                # here instead of reaching the loop-level
                                # handlers.
                                logger.exception("处理音频消息失败: {}", e)
                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.ERROR,
                                        "data": {
                                            "message": "语音处理失败,请重试或使用文字输入"
                                        },
                                        "timestamp": datetime.now(
                                            timezone.utc
                                        ).isoformat(),
                                    },
                                )
                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        # Transcribe-only: return the transcript without
                        # storing a segment or running the reply pipeline.
                        # No quota check is performed in this branch.
                        data = message.get("data", {})
                        audio_base64 = data.get("audio_base64", "")
                        if not audio_base64:
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "缺少 audio_base64"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                            continue
                        try:
                            asr = get_asr_provider()
                            audio_bytes = base64.b64decode(audio_base64)
                            asr_text = await asr.transcribe(audio_bytes, "m4a")
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {"text": asr_text or ""},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                        except Exception as e:
                            logger.exception("仅转写失败: {}", e)
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "语音转写失败,请重试"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)
                    elif msg_type == MessageType.END_CONVERSATION:
                        # End the conversation, process its segments, confirm
                        # to the client, then leave the receive loop.
                        await conversation_service.end(conversation_id, user_id)
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": datetime.now(timezone.utc).isoformat(),
                            },
                        )
                        break
                    elif msg_type == MessageType.PING:
                        # Heartbeat: answer with PONG.
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": datetime.now(timezone.utc).isoformat(),
                            },
                        )
                    else:
                        # A missing type was already warned about above; only
                        # unknown non-null types are logged here.
                        if msg_type is not None:
                            logger.warning(
                                "WebSocket 未识别的消息 type={} conversation_id={}",
                                msg_type,
                                conversation_id,
                            )
                except RuntimeError as e:
                    error_msg = str(e)
                    # Disconnect-like RuntimeErrors are recognised by message
                    # text. `and` binds tighter than `or`: each of the first
                    # two substrings suffices alone, the last two must both be
                    # present.
                    if (
                        "disconnect" in error_msg.lower()
                        or 'Cannot call "receive"' in error_msg
                        or "accept" in error_msg.lower()
                        and "not connected" in error_msg.lower()
                    ):
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                        break
                    else:
                        logger.exception("处理消息时发生 RuntimeError: {}", e)
                        if conversation_id in manager.active_connections:
                            try:
                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.ERROR,
                                        "data": {"message": "处理失败,请重试"},
                                        "timestamp": datetime.now(
                                            timezone.utc
                                        ).isoformat(),
                                    },
                                )
                            except Exception as send_error:
                                logger.warning("发送错误消息失败: {}", send_error)
                        # NOTE(review): indentation was lost in this view —
                        # this break is assumed to run unconditionally after
                        # the error report; it may instead belong to the
                        # failed-send handler above. Confirm.
                        break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    logger.exception("处理消息时发生错误: {}", e)
                    if conversation_id in manager.active_connections:
                        try:
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.ERROR,
                                    "data": {"message": "处理失败,请重试"},
                                    "timestamp": datetime.now(timezone.utc).isoformat(),
                                },
                            )
                        except Exception as send_error:
                            logger.warning("发送错误消息失败: {}", send_error)
                    # NOTE(review): indentation was lost in this view — this
                    # break is assumed to run unconditionally; it may instead
                    # belong to the failed-send handler above. Confirm.
                    break
        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
            # NOTE(review): the finally clause repeats these two calls, so on
            # this path they run twice.
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)
        except Exception as e:
            logger.exception("WebSocket 端点发生错误: {}", e)
            # NOTE(review): duplicated by the finally clause, as above.
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)
        finally:
            # Teardown for every exit path of the try block, early returns
            # included.
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)