refactor(api,expo): 多智能体与会话收敛、回忆录兼容层移除、后端测试集大幅删减
- 对齐「多智能体收敛」与「回忆录 stories-first / markdown-first」方向:收紧运行时契约、 删除过渡兼容路径与双轨逻辑,并同步更新客户端与文档。 - Chat:以 ChatOrchestrator 为实时编排入口;删除独立 conversation_agent,精简 prompts。 - Memoir:删除 memory_agent;MemoirOrchestrator、classification / story_route 与 prompts 收敛到 prepare_batches + run_story_pipeline_for_category_batch 主链路。 - 将 agents 侧 processor 迁入 feature 层为 background_runner,并移除 features 下重复/过时 processor 封装。 - 新增 history_store,强化「conversation_messages 为 DB 真源、Redis 为缓存」模型。 - 调整 models、repo、service、session_history;精简 WS message_types,重构 pipeline 与 router。 - 移除章节占位、整章再生等旧路径;章节列表与封面逻辑要求 story 关联;收紧 cover 资格与 enqueue。 - helpers、repo、service、router、reading_segment_materialize、story_pipeline_sync、pdf_service 等按 canonical markdown / cover_asset_id 收缩;删除 memoir_images/provider 等冗余。 - tasks:memoir_tasks、chapter_cover_tasks 等大幅瘦身;story_image_tasks 等与当前图片任务对齐。 - core:config、logging、redis、task_tracker 小幅调整。 - auth / user / payment / quota:路由或服务侧删减过时接口或逻辑(如 payment router 行数减少)。 - pyproject.toml、development.sh、.env.example / .env.production、README 等同步说明或变量。 - Alembic 0001_initial_schema 微调(与当前 schema 叙事一致的小改动)。 - 回忆录:types / mappers / api、章节页与 memoir 页与后端契约对齐;markdown-renderer 调整。 - 语音:删除 voice/player,voice-segment-store 相应精简。 - api/tests:删除 conftest 及绝大部分既有测试文件(websocket_baseline、conversation、memoir 图片、PDF、SMS 等),属有意收缩/待按 backend-test-system 重建的信号。 - docs:新增多智能体收敛与移除兼容层计划摘要;更新 story-first 设计、backend-test-system、 multi-agent-refactor-plan、实施总结等。 BREAKING CHANGE: 后端对外契约、回忆录章节字段与若干路由/任务行为已变更;大量 API 测试被移除, CI 若依赖这些用例需按新策略补测或调整流水线。
This commit is contained in:
@@ -2,8 +2,6 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
LEGACY_VOICE_SESSION_ID = "legacy"
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""WebSocket 消息类型"""
|
||||
|
||||
@@ -15,24 +15,20 @@ if TYPE_CHECKING:
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents import ConversationAgent, MemoryAgent
|
||||
from app.agents.chat import ChatOrchestrator
|
||||
from app.agents.memoir import BackgroundTaskRunner
|
||||
from app.core.config import settings
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
|
||||
from app.core.redis import redis_service
|
||||
from app.features.conversation.history_store import ConversationHistoryStore
|
||||
from app.features.conversation.models import Conversation, Segment
|
||||
from app.features.conversation.ws.connection_manager import manager
|
||||
from app.features.conversation.ws.message_types import (
|
||||
LEGACY_VOICE_SESSION_ID,
|
||||
MessageType,
|
||||
)
|
||||
from app.features.conversation.ws.message_types import MessageType
|
||||
from app.features.conversation.ws.profile_collector import (
|
||||
apply_extracted_profile,
|
||||
get_filled_profile_fields,
|
||||
get_missing_profile_fields,
|
||||
)
|
||||
from app.features.memoir.background_runner import BackgroundTaskRunner
|
||||
from app.features.user.models import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -77,9 +73,6 @@ async def _send_tts_audio(
|
||||
storage = get_object_storage()
|
||||
key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
|
||||
public_url = storage.upload(key, audio_bytes, content_type)
|
||||
await redis_service.append_tts_audio_url_to_last_ai_message(
|
||||
conversation_id, public_url
|
||||
)
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
@@ -109,9 +102,7 @@ async def _send_tts_audio(
|
||||
|
||||
|
||||
# ── Agent 实例(从 ConnectionManager 移出) ─────────────────────
|
||||
conversation_agent = ConversationAgent()
|
||||
chat_orchestrator = ChatOrchestrator()
|
||||
memory_agent = MemoryAgent()
|
||||
background_runner = BackgroundTaskRunner()
|
||||
|
||||
|
||||
@@ -197,12 +188,6 @@ def _mark_conversation_active(
|
||||
return activity_time
|
||||
|
||||
|
||||
def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
|
||||
if voice_session_id:
|
||||
return str(voice_session_id)
|
||||
return LEGACY_VOICE_SESSION_ID
|
||||
|
||||
|
||||
def _voice_session_id_from_client_segment_id(
|
||||
client_segment_id: Optional[str],
|
||||
) -> Optional[str]:
|
||||
@@ -220,19 +205,19 @@ def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
|
||||
|
||||
|
||||
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
|
||||
"""从 audio_url 中解析 voice_session_id 与 segment_index。兼容旧格式 audio-segment:{index}。"""
|
||||
"""从 audio_url 解析 voice_session_id 与 segment_index(audio-segment:{session_id}:{index})。"""
|
||||
prefix = "audio-segment:"
|
||||
if not audio_url or not audio_url.startswith(prefix):
|
||||
return None
|
||||
payload = audio_url[len(prefix) :]
|
||||
voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
|
||||
if not separator:
|
||||
return None
|
||||
try:
|
||||
if separator:
|
||||
return (
|
||||
_normalize_voice_session_id(voice_session_id_raw),
|
||||
int(segment_index_raw),
|
||||
)
|
||||
return (LEGACY_VOICE_SESSION_ID, int(payload))
|
||||
sid = str(voice_session_id_raw).strip()
|
||||
if not sid:
|
||||
return None
|
||||
return (sid, int(segment_index_raw))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@@ -452,9 +437,14 @@ async def process_audio_segment(
|
||||
if existing_segment:
|
||||
async with state.lock:
|
||||
state.processed_indices.add(segment_index)
|
||||
logger.info(
|
||||
"分段已存在,按幂等处理跳过: "
|
||||
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
|
||||
logger.debug(
|
||||
"分段已存在,按幂等跳过: conversation_id=%s voice_session_id=%s "
|
||||
"segment_index=%s segment_id=%s transcript=%s",
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
segment_index,
|
||||
existing_segment.id,
|
||||
existing_segment.transcript_text or "",
|
||||
)
|
||||
return
|
||||
else:
|
||||
@@ -535,6 +525,8 @@ async def process_user_message(
|
||||
user_message_timestamp: Optional[datetime] = None,
|
||||
) -> None:
|
||||
"""处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent。"""
|
||||
store = ConversationHistoryStore(db)
|
||||
tts_urls: list[str] = []
|
||||
try:
|
||||
is_from_voice = bool(segment.audio_url)
|
||||
voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
|
||||
@@ -558,9 +550,18 @@ async def process_user_message(
|
||||
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
_mark_conversation_active(conversation)
|
||||
await db.commit()
|
||||
await store.record_human_ai_turn(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
responses=responses,
|
||||
user_message_timestamp=user_message_timestamp,
|
||||
is_from_voice=is_from_voice,
|
||||
voice_session_id=voice_session_id,
|
||||
audio_duration_seconds=audio_dur,
|
||||
tts_audio_urls=None,
|
||||
segment_id=segment.id,
|
||||
)
|
||||
|
||||
tts_urls: list[str] = []
|
||||
n = len(responses)
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(
|
||||
@@ -589,14 +590,35 @@ async def process_user_message(
|
||||
if i < n - 1:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
await db.execute(
|
||||
update(Segment)
|
||||
.where(Segment.id == segment.id)
|
||||
.values(tts_audio_urls=tts_urls if tts_urls else None)
|
||||
)
|
||||
await db.commit()
|
||||
if tts_urls:
|
||||
await store.attach_ai_tts_audio_urls(
|
||||
conversation_id,
|
||||
tts_audio_urls=tts_urls,
|
||||
segment_id=segment.id,
|
||||
)
|
||||
await db.execute(
|
||||
update(Segment)
|
||||
.where(Segment.id == segment.id)
|
||||
.values(tts_audio_urls=tts_urls)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
if tts_urls:
|
||||
try:
|
||||
await store.attach_ai_tts_audio_urls(
|
||||
conversation_id,
|
||||
tts_audio_urls=tts_urls,
|
||||
segment_id=segment.id,
|
||||
)
|
||||
await db.execute(
|
||||
update(Segment)
|
||||
.where(Segment.id == segment.id)
|
||||
.values(tts_audio_urls=tts_urls)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as persist_error:
|
||||
logger.warning("补写 TTS 元数据失败: %s", persist_error)
|
||||
logger.error(f"处理用户消息失败: {e}", exc_info=True)
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
|
||||
@@ -16,19 +16,18 @@ from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_asr_provider
|
||||
from app.core.logging import get_logger
|
||||
from app.core.security import verify_token
|
||||
from app.features.conversation.history_store import ConversationHistoryStore
|
||||
from app.features.conversation.models import Conversation, Segment
|
||||
from app.features.conversation.service import ConversationService
|
||||
from app.features.conversation.ws.connection_manager import manager
|
||||
from app.features.conversation.ws.message_types import MessageType
|
||||
from app.features.conversation.ws.pipeline import (
|
||||
SegmentStreamState, # noqa: F401 — re-export for test backward compat
|
||||
_delayed_listening_feedback,
|
||||
_mark_conversation_active,
|
||||
_normalize_voice_session_id,
|
||||
_voice_session_id_from_client_segment_id,
|
||||
background_runner,
|
||||
chat_orchestrator,
|
||||
cleanup_segment_states,
|
||||
conversation_agent,
|
||||
get_or_create_segment_state,
|
||||
process_audio_segment,
|
||||
process_conversation_segments,
|
||||
@@ -144,18 +143,21 @@ async def websocket_endpoint(
|
||||
)
|
||||
return
|
||||
|
||||
history = await conversation_service.ensure_redis_history_from_segments(
|
||||
history = await conversation_service.ensure_redis_history_from_db(
|
||||
conversation_id
|
||||
)
|
||||
if not history:
|
||||
missing_profile = get_missing_profile_fields(user)
|
||||
if missing_profile:
|
||||
try:
|
||||
greetings = await conversation_agent.generate_profile_greeting(
|
||||
greetings = await chat_orchestrator.generate_profile_greeting(
|
||||
conversation_id=conversation_id,
|
||||
missing_fields=missing_profile,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
await ConversationHistoryStore(db).record_ai_only_turn(
|
||||
conversation_id, greetings
|
||||
)
|
||||
for i, text in enumerate(greetings):
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
@@ -184,12 +186,15 @@ async def websocket_endpoint(
|
||||
occupation=user.occupation,
|
||||
)
|
||||
opening_messages = (
|
||||
await conversation_agent.generate_opening_message(
|
||||
await chat_orchestrator.generate_opening_message(
|
||||
conversation_id=conversation_id,
|
||||
memoir_state=state,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
)
|
||||
await ConversationHistoryStore(db).record_ai_only_turn(
|
||||
conversation_id, opening_messages
|
||||
)
|
||||
for i, text in enumerate(opening_messages):
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
@@ -212,8 +217,9 @@ async def websocket_endpoint(
|
||||
while True:
|
||||
try:
|
||||
if websocket.application_state != WebSocketState.CONNECTED:
|
||||
logger.info(
|
||||
f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}"
|
||||
logger.debug(
|
||||
"WebSocket 已非连接状态,退出循环: conversation_id=%s",
|
||||
conversation_id,
|
||||
)
|
||||
break
|
||||
message = await websocket.receive_json()
|
||||
@@ -271,9 +277,18 @@ async def websocket_endpoint(
|
||||
|
||||
elif msg_type == MessageType.RECORDING_STARTED:
|
||||
data = message.get("data", {})
|
||||
voice_session_id = _normalize_voice_session_id(
|
||||
data.get("voice_session_id")
|
||||
)
|
||||
raw_vs = data.get("voice_session_id")
|
||||
if not raw_vs or not str(raw_vs).strip():
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "缺少 voice_session_id"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
voice_session_id = str(raw_vs).strip()
|
||||
segment_state = get_or_create_segment_state(
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
@@ -298,12 +313,24 @@ async def websocket_endpoint(
|
||||
data = message.get("data", {})
|
||||
audio_base64 = data.get("audio_base64", "")
|
||||
segment_index_raw = data.get("segment_index")
|
||||
voice_session_id = _normalize_voice_session_id(
|
||||
data.get("voice_session_id")
|
||||
or _voice_session_id_from_client_segment_id(
|
||||
resolved_vs = data.get("voice_session_id") or (
|
||||
_voice_session_id_from_client_segment_id(
|
||||
data.get("client_segment_id")
|
||||
)
|
||||
)
|
||||
if not resolved_vs or not str(resolved_vs).strip():
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
{
|
||||
"type": MessageType.ERROR,
|
||||
"data": {
|
||||
"message": "缺少 voice_session_id 或有效的 client_segment_id"
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
voice_session_id = str(resolved_vs).strip()
|
||||
is_last = bool(data.get("is_last", False))
|
||||
audio_duration = int(data.get("duration", 0) or 0)
|
||||
|
||||
@@ -386,9 +413,14 @@ async def websocket_endpoint(
|
||||
should_process = True
|
||||
|
||||
if not should_process:
|
||||
logger.info(
|
||||
"收到重复分段,跳过处理: "
|
||||
f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
|
||||
logger.debug(
|
||||
"收到重复分段,跳过: conversation_id=%s voice_session_id=%s "
|
||||
"segment_index=%s audio_b64_len=%s duration=%s",
|
||||
conversation_id,
|
||||
voice_session_id,
|
||||
segment_index,
|
||||
len(audio_base64 or ""),
|
||||
audio_duration,
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -437,7 +469,11 @@ async def websocket_endpoint(
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"收到音频消息,时长: {audio_duration}s")
|
||||
logger.debug(
|
||||
"收到音频消息: conversation_id=%s duration_s=%s",
|
||||
conversation_id,
|
||||
audio_duration,
|
||||
)
|
||||
|
||||
try:
|
||||
asr = get_asr_provider()
|
||||
@@ -445,7 +481,16 @@ async def websocket_endpoint(
|
||||
transcript_text = await asr.transcribe(
|
||||
audio_bytes, "m4a"
|
||||
)
|
||||
logger.info("ASR 转写结果: %s", transcript_text)
|
||||
logger.debug(
|
||||
"ASR 转写完成: conversation_id=%s chars=%s",
|
||||
conversation_id,
|
||||
len(transcript_text or ""),
|
||||
)
|
||||
logger.debug(
|
||||
"ASR 转写全文: conversation_id=%s text=%s",
|
||||
conversation_id,
|
||||
transcript_text,
|
||||
)
|
||||
|
||||
await manager.send_message(
|
||||
conversation_id,
|
||||
@@ -591,8 +636,10 @@ async def websocket_endpoint(
|
||||
or "accept" in error_msg.lower()
|
||||
and "not connected" in error_msg.lower()
|
||||
):
|
||||
logger.info(
|
||||
f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}"
|
||||
logger.debug(
|
||||
"WebSocket 连接已断开或未就绪: conversation_id=%s error=%s",
|
||||
conversation_id,
|
||||
error_msg,
|
||||
)
|
||||
break
|
||||
else:
|
||||
@@ -613,8 +660,8 @@ async def websocket_endpoint(
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(
|
||||
f"WebSocket 断开连接: conversation_id={conversation_id}"
|
||||
logger.debug(
|
||||
"WebSocket 断开连接: conversation_id=%s", conversation_id
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -634,7 +681,7 @@ async def websocket_endpoint(
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
logger.debug("WebSocket 断开连接: conversation_id=%s", conversation_id)
|
||||
await manager.disconnect(conversation_id)
|
||||
cleanup_segment_states(conversation_id)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user