Files
life-echo/api/test_conversation.py
penghanyuan 3f899aa16c feat: 新增任务状态管理和API支持
- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表
- 实现任务追踪服务,使用Redis存储任务状态
- 更新回忆录处理逻辑,集成Celery任务提交和状态更新
- 增强测试用例,支持任务状态的获取和清除功能
- 优化代码结构,提升可读性和维护性
2026-01-21 23:37:00 +01:00

377 lines
14 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
多轮对话测试脚本
测试对话引导 Agent 和回忆录整理功能
"""
import asyncio
import json
import uuid
import httpx
import websockets
from datetime import datetime
# 配置
BASE_URL = "http://localhost:8000"
WS_URL = "ws://localhost:8000"
# 测试用户信息
TEST_PHONE = f"138{uuid.uuid4().hex[:8]}" # 随机手机号避免冲突
TEST_PASSWORD = "test123456"
TEST_NICKNAME = "测试用户"
# 模拟用户的多轮对话内容(关于童年和教育阶段)
CONVERSATION_MESSAGES = [
# 童年阶段
"我出生在南方一个小镇,小时候跟奶奶住在一起。",
"奶奶家有个小院子,夏天的时候我们经常坐在院子里乘凉,她给我讲故事。",
"那段时光真的很美好,我记得奶奶总是给我做红烧肉,那是我最爱吃的菜。",
"小时候最开心的事就是过年,可以放鞭炮,还能收到压岁钱。",
# 教育阶段
"后来我去城里上学了,那是我第一次离开家,心里特别害怕。",
"初中的时候遇到了一个很好的语文老师,她鼓励我多读书,对我影响很大。",
"高考那年压力特别大,但最后还是考上了理想的大学。",
]
class ConversationTester:
"""对话测试器"""
def __init__(self):
self.token = None
self.user_id = None
self.conversation_id = str(uuid.uuid4())
async def register_or_login(self):
"""注册或登录用户"""
async with httpx.AsyncClient(timeout=30.0) as client:
# 先尝试注册
print(f"\n📝 注册用户: {TEST_PHONE}")
resp = await client.post(
f"{BASE_URL}/api/auth/register",
json={
"phone": TEST_PHONE,
"password": TEST_PASSWORD,
"nickname": TEST_NICKNAME
}
)
if resp.status_code == 201:
data = resp.json()
self.token = data["access_token"]
print(f"✅ 注册成功!")
elif resp.status_code == 400 and "已被注册" in resp.text:
# 已注册,尝试登录
print(f" 用户已存在,尝试登录...")
resp = await client.post(
f"{BASE_URL}/api/auth/login",
json={
"phone": TEST_PHONE,
"password": TEST_PASSWORD
}
)
if resp.status_code == 200:
data = resp.json()
self.token = data["access_token"]
print(f"✅ 登录成功!")
else:
raise Exception(f"登录失败: {resp.text}")
else:
raise Exception(f"注册失败: {resp.text}")
print(f"🔑 Token: {self.token[:30]}...")
async def get_memoir_state(self):
"""获取回忆录状态"""
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.get(
f"{BASE_URL}/api/memoir-state",
headers={"Authorization": f"Bearer {self.token}"}
)
if resp.status_code != 200:
print(f" ⚠️ 状态API返回 {resp.status_code}: {resp.text[:200]}")
return {"current_stage": "unknown", "covered_stages": [], "slots": {}}
return resp.json()
async def get_chapters(self):
"""获取章节列表"""
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.get(
f"{BASE_URL}/api/chapters",
headers={"Authorization": f"Bearer {self.token}"}
)
if resp.status_code != 200:
print(f" ⚠️ 章节API返回 {resp.status_code}: {resp.text[:200]}")
return []
return resp.json()
async def get_book(self):
"""获取回忆录信息"""
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.get(
f"{BASE_URL}/api/books/current",
headers={"Authorization": f"Bearer {self.token}"}
)
if resp.status_code != 200:
print(f" ⚠️ 回忆录API返回 {resp.status_code}: {resp.text[:200]}")
return {"message": "获取失败"}
return resp.json()
async def get_tasks_status(self):
"""获取任务状态"""
async with httpx.AsyncClient(timeout=60.0) as client:
resp = await client.get(
f"{BASE_URL}/api/tasks/status",
headers={"Authorization": f"Bearer {self.token}"}
)
if resp.status_code != 200:
return {"total": 0, "all_completed": True, "tasks": []}
return resp.json()
async def clear_tasks(self):
"""清除任务记录"""
async with httpx.AsyncClient(timeout=60.0) as client:
await client.delete(
f"{BASE_URL}/api/tasks/clear",
headers={"Authorization": f"Bearer {self.token}"}
)
async def run_conversation(self):
"""运行多轮对话"""
print(f"\n🔗 连接 WebSocket: {self.conversation_id}")
ws_url = f"{WS_URL}/ws/conversation/{self.conversation_id}?token={self.token}"
async with websockets.connect(ws_url) as ws:
# 接收连接确认
msg = await ws.recv()
data = json.loads(msg)
print(f"✅ 连接成功: {data['type']}")
# 多轮对话
for i, user_message in enumerate(CONVERSATION_MESSAGES, 1):
print(f"\n{'='*60}")
print(f"📤 第 {i} 轮对话")
print(f"{'='*60}")
print(f"👤 用户: {user_message}")
# 发送消息
await ws.send(json.dumps({
"type": "text",
"data": {"text": user_message}
}))
# 接收 Agent 回复(可能是多条消息)
try:
while True:
msg = await asyncio.wait_for(ws.recv(), timeout=30)
data = json.loads(msg)
if data["type"] == "agent_response":
msg_data = data['data']
total = msg_data.get('total', 1)
index = msg_data.get('index', 0)
print(f"🤖 Agent: {msg_data['text']}")
# 如果是最后一条消息,退出循环
if index >= total - 1:
break
elif data["type"] == "error":
print(f"❌ 错误: {data['data']['message']}")
break
else:
break
except asyncio.TimeoutError:
print("⏰ 等待响应超时")
# 短暂等待,模拟真实对话节奏
await asyncio.sleep(1)
# 结束对话
print(f"\n{'='*60}")
print("📭 结束对话")
print(f"{'='*60}")
await ws.send(json.dumps({
"type": "end_conversation",
"conversation_id": self.conversation_id
}))
try:
# 结束时会触发 process_conversation_segments可能需要更长时间
msg = await asyncio.wait_for(ws.recv(), timeout=60)
data = json.loads(msg)
if data['type'] == 'error':
print(f"❌ 结束对话错误: {data['data'].get('message', 'unknown')}")
else:
print(f"✅ 对话结束: {data['type']}")
except asyncio.TimeoutError:
print("⏰ 等待结束确认超时(但后台处理可能仍在进行)")
async def wait_for_processing(self, max_wait_seconds: int = 300, check_interval: int = 3):
"""
等待后台处理完成
通过查询 Celery 任务状态来判断处理是否完成
Args:
max_wait_seconds: 最大等待时间(秒),默认 5 分钟
check_interval: 检查间隔(秒)
Returns:
是否在超时前完成
"""
print(f"\n⏳ 等待后台任务完成(最多 {max_wait_seconds} 秒)...")
print(" 提示: 通过 Celery 任务状态 API 追踪任务进度")
start_time = asyncio.get_event_loop().time()
while True:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed >= max_wait_seconds:
print(f"\n⚠️ 已等待 {max_wait_seconds} 秒,超时退出")
return False
# 检查任务状态
tasks_status = await self.get_tasks_status()
total = tasks_status.get("total", 0)
pending = tasks_status.get("pending", 0)
running = tasks_status.get("running", 0)
success = tasks_status.get("success", 0)
failure = tasks_status.get("failure", 0)
all_completed = tasks_status.get("all_completed", False)
# 同时检查章节内容
chapters = await self.get_chapters()
chapter_count = len(chapters)
total_content_length = sum(len(ch.get('content', '')) for ch in chapters)
status_str = f"📊 总:{total} 等待:{pending} 运行:{running} 成功:{success} 失败:{failure}"
content_str = f"📚 章节:{chapter_count} 内容:{total_content_length}字符"
print(f" [{int(elapsed):3d}s] {status_str} | {content_str}")
# 判断是否完成:
# 1. 有任务且全部完成
# 2. 或者没有任务但有章节内容(兼容旧逻辑)
if total > 0 and all_completed:
print(f"\n✅ 所有任务已完成!共 {total} 个任务,等待 {int(elapsed)}")
return True
# 如果没有任务记录,等待一会儿任务提交
if total == 0 and elapsed < 15:
await asyncio.sleep(check_interval)
continue
# 如果长时间没有任务但有内容,也认为完成
if total == 0 and chapter_count > 0 and elapsed > 30:
print(f"\n✅ 无待处理任务,已有 {chapter_count} 个章节。等待 {int(elapsed)}")
return True
await asyncio.sleep(check_interval)
async def check_results(self):
"""检查回忆录生成结果"""
print(f"\n{'='*60}")
print("📊 检查结果")
print(f"{'='*60}")
# 等待后台处理完成(使用智能轮询)
await self.wait_for_processing(max_wait_seconds=180, check_interval=5)
# 获取回忆录状态
print("\n📋 回忆录状态:")
state = await self.get_memoir_state()
print(f" 当前阶段: {state.get('current_stage', 'N/A')}")
print(f" 已完成阶段: {state.get('covered_stages', [])}")
# 显示已填充的 slots
slots = state.get('slots', {})
for stage, stage_slots in slots.items():
filled = [k for k, v in stage_slots.items() if v.get('snippet')]
if filled:
print(f" {stage} 已填充: {filled}")
for slot_name in filled:
snippet = stage_slots[slot_name].get('snippet', '')
if snippet:
print(f" - {slot_name}: {snippet[:50]}...")
# 获取章节
print("\n📚 生成的章节:")
chapters = await self.get_chapters()
if chapters:
for ch in chapters:
is_new = "🆕" if ch.get("is_new") else ""
content_len = len(ch.get('content', ''))
print(f" {is_new} [{ch.get('category', 'N/A')}] {ch.get('title', 'N/A')} ({content_len} 字符)")
else:
print(" (暂无章节)")
# 获取回忆录
print("\n📖 回忆录信息:")
book = await self.get_book()
if "message" not in book:
print(f" 标题: {book.get('title', 'N/A')}")
print(f" 总字数: {book.get('total_words', 0)}")
print(f" 有更新: {'' if book.get('has_update') else ''}")
else:
print(f" {book.get('message', 'N/A')}")
# 显示回忆录完整内容
if chapters:
print(f"\n{'='*60}")
print("📜 回忆录完整内容")
print(f"{'='*60}")
for ch in chapters:
category = ch.get('category', 'N/A')
title = ch.get('title', '未命名章节')
content = ch.get('content', '')
print(f"\n{''*60}")
print(f"{title}】({category})")
print(f"{''*60}")
if content:
print(content)
else:
print("(暂无内容)")
print(f"\n{'='*60}")
async def main():
"""主函数"""
print("=" * 60)
print("🎭 Life Echo 多轮对话测试")
print(f"⏰ 开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 60)
tester = ConversationTester()
try:
# 1. 注册/登录
await tester.register_or_login()
# 2. 清除旧的任务记录
await tester.clear_tasks()
print("\n🧹 已清除旧的任务记录")
# 3. 查看初始状态
print("\n📋 初始回忆录状态:")
state = await tester.get_memoir_state()
print(f" 当前阶段: {state.get('current_stage', 'N/A')}")
# 4. 运行多轮对话
await tester.run_conversation()
# 5. 检查结果
await tester.check_results()
except Exception as e:
print(f"\n❌ 测试失败: {e}")
import traceback
traceback.print_exc()
print("\n" + "=" * 60)
print(f"⏰ 结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())