agent init
This commit is contained in:
@@ -138,30 +138,48 @@ async def register(
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(
|
||||
request: LoginRequest,
|
||||
request: LoginRequest = None,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""
|
||||
用户登录
|
||||
|
||||
验证手机号和密码,返回访问令牌和刷新令牌
|
||||
验证手机号和密码,返回访问令牌和刷新令牌。
|
||||
|
||||
支持两种格式:
|
||||
- JSON: {"phone": "13800138000", "password": "xxx"}
|
||||
- 表单 (Swagger UI): username=13800138000&password=xxx
|
||||
"""
|
||||
# 优先使用表单数据(Swagger UI),否则使用 JSON
|
||||
if form_data and form_data.username:
|
||||
phone = form_data.username
|
||||
password = form_data.password
|
||||
elif request:
|
||||
phone = request.phone
|
||||
password = request.password
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="请提供手机号和密码"
|
||||
)
|
||||
|
||||
# 验证手机号格式(简单验证)
|
||||
if not request.phone or len(request.phone) != 11 or not request.phone.isdigit():
|
||||
if not phone or len(phone) != 11 or not phone.isdigit():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="手机号格式不正确,应为11位数字"
|
||||
)
|
||||
|
||||
# 验证密码不为空
|
||||
if not request.password:
|
||||
if not password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="密码不能为空"
|
||||
)
|
||||
|
||||
# 查找用户
|
||||
stmt = select(User).where(User.phone == request.phone)
|
||||
stmt = select(User).where(User.phone == phone)
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@@ -172,7 +190,7 @@ async def login(
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(request.password, user.password_hash):
|
||||
if not verify_password(password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="手机号或密码错误"
|
||||
|
||||
@@ -33,10 +33,28 @@ async def get_current_book(
|
||||
"title": book.title,
|
||||
"total_pages": book.total_pages,
|
||||
"total_words": book.total_words,
|
||||
"cover_image_url": book.cover_image_url
|
||||
"cover_image_url": book.cover_image_url,
|
||||
"has_update": book.has_update,
|
||||
"last_update_chapter_id": book.last_update_chapter_id,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/clear-update")
|
||||
async def clear_book_update(
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
):
|
||||
"""清除回忆录更新标记"""
|
||||
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
|
||||
result = await db.execute(stmt)
|
||||
book = result.scalar_one_or_none()
|
||||
if not book:
|
||||
return {"status": "ok", "message": "No book found"}
|
||||
book.has_update = False
|
||||
await db.commit()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class ExportPdfRequest(BaseModel):
|
||||
book_id: str
|
||||
|
||||
|
||||
@@ -18,10 +18,14 @@ router = APIRouter(prefix="/api/chapters", tags=["chapters"])
|
||||
@router.get("", response_model=List[dict])
|
||||
async def get_chapters(
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
is_new: Optional[bool] = Query(None, description="仅返回未读章节"),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""获取用户所有章节(需要认证)"""
|
||||
stmt = select(ChapterModel).where(ChapterModel.user_id == current_user.id).order_by(ChapterModel.order_index)
|
||||
stmt = select(ChapterModel).where(ChapterModel.user_id == current_user.id)
|
||||
if is_new is True:
|
||||
stmt = stmt.where(ChapterModel.is_new == True)
|
||||
stmt = stmt.order_by(ChapterModel.order_index)
|
||||
result = await db.execute(stmt)
|
||||
chapters = result.scalars().all()
|
||||
|
||||
@@ -33,7 +37,10 @@ async def get_chapters(
|
||||
"order_index": ch.order_index,
|
||||
"status": ch.status,
|
||||
"category": ch.category,
|
||||
"images": ch.images or []
|
||||
"images": ch.images or [],
|
||||
"updated_at": ch.updated_at.isoformat() if ch.updated_at else None,
|
||||
"is_new": ch.is_new,
|
||||
"source_segments": ch.source_segments or [],
|
||||
}
|
||||
for ch in chapters
|
||||
]
|
||||
@@ -61,7 +68,10 @@ async def get_chapter(
|
||||
"order_index": chapter.order_index,
|
||||
"status": chapter.status,
|
||||
"category": chapter.category,
|
||||
"images": chapter.images or []
|
||||
"images": chapter.images or [],
|
||||
"updated_at": chapter.updated_at.isoformat() if chapter.updated_at else None,
|
||||
"is_new": chapter.is_new,
|
||||
"source_segments": chapter.source_segments or [],
|
||||
}
|
||||
|
||||
|
||||
|
||||
61
api/routers/memoir_state.py
Normal file
61
api/routers/memoir_state.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
回忆录状态 API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_async_db
|
||||
from database.models import Book as BookModel
|
||||
from database.models import Chapter as ChapterModel
|
||||
from database.models import User as UserModel
|
||||
from middleware.auth import get_current_user
|
||||
from services.memoir_state_service import get_or_create_state
|
||||
|
||||
router = APIRouter(prefix="/api/memoir-state", tags=["memoir-state"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_memoir_state(
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
):
|
||||
"""获取当前用户回忆录状态"""
|
||||
state = await get_or_create_state(current_user.id, db)
|
||||
return state.model_dump()
|
||||
|
||||
|
||||
@router.get("/next-question")
|
||||
async def get_next_question_context(
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
):
|
||||
"""获取下一步问题的上下文(当前阶段与空 slot)"""
|
||||
state = await get_or_create_state(current_user.id, db)
|
||||
return {
|
||||
"current_stage": state.current_stage,
|
||||
"empty_slots": state.empty_slots_for_current_stage(),
|
||||
"covered_stages": state.covered_stages,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/mark-read")
|
||||
async def mark_memoir_read(
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
):
|
||||
"""标记回忆录更新已读"""
|
||||
stmt = select(ChapterModel).where(ChapterModel.user_id == current_user.id, ChapterModel.is_new == True)
|
||||
result = await db.execute(stmt)
|
||||
chapters = result.scalars().all()
|
||||
for chapter in chapters:
|
||||
chapter.is_new = False
|
||||
|
||||
stmt_book = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
|
||||
result_book = await db.execute(stmt_book)
|
||||
book = result_book.scalar_one_or_none()
|
||||
if book:
|
||||
book.has_update = False
|
||||
|
||||
await db.commit()
|
||||
return {"status": "ok"}
|
||||
@@ -11,11 +11,12 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from agents import ConversationAgent, MemoryAgent
|
||||
from agents.prompts import ConversationStage
|
||||
from agents.memoir_processor import BackgroundTaskRunner
|
||||
from database import get_async_db
|
||||
from database.models import Conversation, Segment
|
||||
from database.models import User as UserModel
|
||||
from services.auth_service import verify_token
|
||||
from services.memoir_state_service import get_or_create_state
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
@@ -28,6 +29,7 @@ class MessageType(str, Enum):
|
||||
AGENT_RESPONSE = "agent_response"
|
||||
TTS_AUDIO = "tts_audio"
|
||||
END_CONVERSATION = "end_conversation"
|
||||
MEMOIR_UPDATE = "memoir_update"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@@ -39,6 +41,7 @@ class ConnectionManager:
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.conversation_agents: Dict[str, ConversationAgent] = {}
|
||||
self.memory_agent = MemoryAgent()
|
||||
self.background_runner = BackgroundTaskRunner()
|
||||
|
||||
async def connect(self, websocket: WebSocket, conversation_id: str):
|
||||
"""建立连接"""
|
||||
@@ -137,8 +140,6 @@ async def websocket_endpoint(
|
||||
return
|
||||
|
||||
|
||||
current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD
|
||||
|
||||
# 主循环:处理消息
|
||||
while True:
|
||||
try:
|
||||
@@ -159,12 +160,13 @@ async def websocket_endpoint(
|
||||
)
|
||||
db.add(segment)
|
||||
await db.commit()
|
||||
await db.refresh(segment)
|
||||
await manager.background_runner.queue_message(conversation.user_id, segment.id)
|
||||
|
||||
# Agent 生成回应
|
||||
current_stage = await process_user_message(
|
||||
await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=text_message,
|
||||
current_stage=current_stage,
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
@@ -206,19 +208,17 @@ async def websocket_endpoint(
|
||||
async def process_user_message(
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
current_stage: ConversationStage,
|
||||
conversation: Conversation,
|
||||
segment: Segment,
|
||||
db: AsyncSession,
|
||||
manager: ConnectionManager
|
||||
) -> ConversationStage:
|
||||
) -> None:
|
||||
"""
|
||||
处理用户消息,生成Agent回应
|
||||
|
||||
Args:
|
||||
conversation_id: 对话ID
|
||||
user_message: 用户消息文本
|
||||
current_stage: 当前对话阶段
|
||||
conversation: 对话对象
|
||||
segment: 段落对象
|
||||
db: 数据库会话
|
||||
@@ -229,14 +229,13 @@ async def process_user_message(
|
||||
"""
|
||||
agent = manager.conversation_agents.get(conversation_id)
|
||||
if agent:
|
||||
# 检测对话阶段
|
||||
detected_stage = agent.detect_stage(conversation_id, user_message)
|
||||
if detected_stage != current_stage:
|
||||
current_stage = detected_stage
|
||||
conversation.conversation_stage = current_stage.value
|
||||
state = await get_or_create_state(conversation.user_id, db)
|
||||
|
||||
if conversation.conversation_stage != state.current_stage:
|
||||
conversation.conversation_stage = state.current_stage
|
||||
await db.commit()
|
||||
|
||||
# 获取已聊话题
|
||||
|
||||
# 获取已聊话题(保留老逻辑用于提示)
|
||||
stmt_segments = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id
|
||||
).order_by(Segment.created_at)
|
||||
@@ -244,37 +243,49 @@ async def process_user_message(
|
||||
previous_segments = result_segments.scalars().all()
|
||||
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
|
||||
|
||||
# 生成回应
|
||||
response = agent.generate_response(
|
||||
# 生成回应(可能是多条消息)
|
||||
responses = agent.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
current_stage=current_stage,
|
||||
covered_topics=covered_topics
|
||||
memoir_state=state
|
||||
)
|
||||
|
||||
# 更新段落的 Agent 回应
|
||||
segment.agent_response = response
|
||||
# 更新段落的 Agent 回应(存储完整内容)
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
await db.commit()
|
||||
|
||||
# 发送 Agent 回应(仅文字,不生成语音)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": response},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
# 发送 Agent 回应(支持多条消息)
|
||||
import asyncio as _asyncio
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
# 多条消息之间稍作间隔,模拟打字效果
|
||||
if i < len(responses) - 1:
|
||||
await _asyncio.sleep(0.5)
|
||||
|
||||
return current_stage
|
||||
return
|
||||
|
||||
|
||||
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
"""
|
||||
处理对话段落,生成章节
|
||||
处理对话段落,生成章节(对话结束时调用)
|
||||
|
||||
注意:大部分处理已通过 BackgroundTaskRunner 增量完成
|
||||
这里只处理可能遗漏的最后几条消息
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
db: 数据库会话
|
||||
"""
|
||||
# 获取对话信息
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
# 获取所有未处理的段落
|
||||
stmt = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id,
|
||||
@@ -286,39 +297,7 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
if not segments:
|
||||
return
|
||||
|
||||
# 准备段落数据
|
||||
segments_data = [
|
||||
{"transcript_text": seg.transcript_text}
|
||||
for seg in segments
|
||||
]
|
||||
|
||||
# 调用整理 Agent
|
||||
memory_agent = manager.memory_agent
|
||||
chapters_data = memory_agent.process_segments(segments_data)
|
||||
|
||||
# 保存章节到数据库
|
||||
from database.models import Chapter as ChapterModel
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
for category, chapter_data in chapters_data.items():
|
||||
chapter = ChapterModel(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=conversation.user_id,
|
||||
title=chapter_data.get("title", f"章节-{category}"),
|
||||
content=chapter_data.get("content", ""),
|
||||
order_index=chapter_data.get("order_index", 999),
|
||||
status="completed",
|
||||
category=category,
|
||||
images=chapter_data.get("image_suggestions", [])
|
||||
)
|
||||
db.add(chapter)
|
||||
|
||||
# 标记段落为已处理
|
||||
# 将未处理的段落加入后台任务队列(不等待完成,避免阻塞)
|
||||
for seg in segments:
|
||||
seg.processed = True
|
||||
|
||||
await db.commit()
|
||||
await manager.background_runner.queue_message(conversation.user_id, seg.id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user