agent init

This commit is contained in:
penghanyuan
2026-01-21 22:31:03 +01:00
parent 426f23c777
commit 44bd478c1e
19 changed files with 1513 additions and 111 deletions

View File

@@ -11,11 +11,12 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from agents import ConversationAgent, MemoryAgent
from agents.prompts import ConversationStage
from agents.memoir_processor import BackgroundTaskRunner
from database import get_async_db
from database.models import Conversation, Segment
from database.models import User as UserModel
from services.auth_service import verify_token
from services.memoir_state_service import get_or_create_state
from fastapi import HTTPException, status
@@ -28,6 +29,7 @@ class MessageType(str, Enum):
AGENT_RESPONSE = "agent_response"
TTS_AUDIO = "tts_audio"
END_CONVERSATION = "end_conversation"
MEMOIR_UPDATE = "memoir_update"
ERROR = "error"
@@ -39,6 +41,7 @@ class ConnectionManager:
self.active_connections: Dict[str, WebSocket] = {}
self.conversation_agents: Dict[str, ConversationAgent] = {}
self.memory_agent = MemoryAgent()
self.background_runner = BackgroundTaskRunner()
async def connect(self, websocket: WebSocket, conversation_id: str):
"""建立连接"""
@@ -137,8 +140,6 @@ async def websocket_endpoint(
return
current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD
# 主循环:处理消息
while True:
try:
@@ -159,12 +160,13 @@ async def websocket_endpoint(
)
db.add(segment)
await db.commit()
await db.refresh(segment)
await manager.background_runner.queue_message(conversation.user_id, segment.id)
# Agent 生成回应
current_stage = await process_user_message(
await process_user_message(
conversation_id=conversation_id,
user_message=text_message,
current_stage=current_stage,
conversation=conversation,
segment=segment,
db=db,
@@ -206,19 +208,17 @@ async def websocket_endpoint(
async def process_user_message(
conversation_id: str,
user_message: str,
current_stage: ConversationStage,
conversation: Conversation,
segment: Segment,
db: AsyncSession,
manager: ConnectionManager
) -> ConversationStage:
) -> None:
"""
处理用户消息生成Agent回应
Args:
conversation_id: 对话ID
user_message: 用户消息文本
current_stage: 当前对话阶段
conversation: 对话对象
segment: 段落对象
db: 数据库会话
@@ -229,14 +229,13 @@ async def process_user_message(
"""
agent = manager.conversation_agents.get(conversation_id)
if agent:
# 检测对话阶段
detected_stage = agent.detect_stage(conversation_id, user_message)
if detected_stage != current_stage:
current_stage = detected_stage
conversation.conversation_stage = current_stage.value
state = await get_or_create_state(conversation.user_id, db)
if conversation.conversation_stage != state.current_stage:
conversation.conversation_stage = state.current_stage
await db.commit()
# 获取已聊话题
# 获取已聊话题(保留老逻辑用于提示)
stmt_segments = select(Segment).where(
Segment.conversation_id == conversation_id
).order_by(Segment.created_at)
@@ -244,37 +243,49 @@ async def process_user_message(
previous_segments = result_segments.scalars().all()
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
# 生成回应
response = agent.generate_response(
# 生成回应(可能是多条消息)
responses = agent.generate_response_with_state(
conversation_id=conversation_id,
user_message=user_message,
current_stage=current_stage,
covered_topics=covered_topics
memoir_state=state
)
# 更新段落的 Agent 回应
segment.agent_response = response
# 更新段落的 Agent 回应(存储完整内容)
segment.agent_response = "\n\n".join(responses)
await db.commit()
# 发送 Agent 回应(仅文字,不生成语音
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 发送 Agent 回应(支持多条消息
import asyncio as _asyncio
for i, response_text in enumerate(responses):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response_text, "index": i, "total": len(responses)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 多条消息之间稍作间隔,模拟打字效果
if i < len(responses) - 1:
await _asyncio.sleep(0.5)
return current_stage
return
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
"""
处理对话段落,生成章节
处理对话段落,生成章节(对话结束时调用)
注意:大部分处理已通过 BackgroundTaskRunner 增量完成
这里只处理可能遗漏的最后几条消息
Args:
conversation_id: 对话 ID
db: 数据库会话
"""
# 获取对话信息
conversation = await db.get(Conversation, conversation_id)
if not conversation:
return
# 获取所有未处理的段落
stmt = select(Segment).where(
Segment.conversation_id == conversation_id,
@@ -286,39 +297,7 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
if not segments:
return
# 准备段落数据
segments_data = [
{"transcript_text": seg.transcript_text}
for seg in segments
]
# 调用整理 Agent
memory_agent = manager.memory_agent
chapters_data = memory_agent.process_segments(segments_data)
# 保存章节到数据库
from database.models import Chapter as ChapterModel
conversation = await db.get(Conversation, conversation_id)
if not conversation:
return
for category, chapter_data in chapters_data.items():
chapter = ChapterModel(
id=str(uuid.uuid4()),
user_id=conversation.user_id,
title=chapter_data.get("title", f"章节-{category}"),
content=chapter_data.get("content", ""),
order_index=chapter_data.get("order_index", 999),
status="completed",
category=category,
images=chapter_data.get("image_suggestions", [])
)
db.add(chapter)
# 标记段落为已处理
# 将未处理的段落加入后台任务队列(不等待完成,避免阻塞)
for seg in segments:
seg.processed = True
await db.commit()
await manager.background_runner.queue_message(conversation.user_id, seg.id)