Files
life-echo/api/agents/state_schema.py
penghanyuan 7fe0b70d5c feat: 增强对话代理以检测用户阶段并更新章节排序
- 在 api/agents/conversation_agent.py 中添加 _detect_user_stage 方法,以通过关键词检测用户谈论的人生阶段。
- 在 api/agents/memory_agent.py 中更新章节排序逻辑,使用 STAGE_TO_ORDER 替代 CHAPTER_ORDER。
- 在 api/agents/state_schema.py 中添加方法以获取各阶段的填充情况。
- 在 api/agents/prompts/conversation_prompts.py 中更新对话提示,包含用户阶段检测和整体进度信息。
- 在 api/migrations/fix_chapter_order_index.sql 中添加 SQL 脚本以修复章节 order_index 的问题。
- 更新相关文档和提示以反映新功能。
2026-02-13 21:45:56 +01:00

111 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
共享状态 Schema（对话 Agent 与后台 Agent 共用）
"""
from __future__ import annotations
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
class SlotData(BaseModel):
    """One memoir slot: an optional text snippet plus the segment IDs behind it.

    ``snippet`` defaults to ``None`` (the slot is unfilled), and every instance
    gets its own fresh, empty ``segment_ids`` list.
    """

    # Text stored for this slot; a falsy value (None / "") means "not filled yet".
    snippet: Optional[str] = Field(default=None)
    # IDs associated with the snippet (presumably source conversation
    # segments — confirm against the memory agent that writes them).
    segment_ids: List[str] = Field(default_factory=list)
class MemoirStateSchema(BaseModel):
    """Memoir state shared by the conversation agent and the background agents.

    Attributes:
        stage_order: Ordered stage names; ``all_stages_coverage`` reports stages
            in exactly this order.
        current_stage: Name of the stage currently in focus.
        covered_stages: Names of stages already covered.
        slots: Mapping of stage name -> slot key -> ``SlotData``.
    """

    stage_order: List[str]
    current_stage: str
    covered_stages: List[str]
    slots: Dict[str, Dict[str, SlotData]]

    def empty_slots_for_current_stage(self) -> List[str]:
        """Return the keys of unfilled slots in ``current_stage``.

        Delegates to ``empty_slots_for_stage`` so both methods share a single
        definition of an "empty" slot.
        """
        return self.empty_slots_for_stage(self.current_stage)

    def empty_slots_for_stage(self, stage: str) -> List[str]:
        """Return the keys of unfilled slots in ``stage``, in slot order.

        A slot is empty when its ``snippet`` is falsy (``None`` or ``""``).
        A stage that is missing from ``slots`` yields an empty list.
        """
        stage_slots = self.slots.get(stage, {})
        return [key for key, value in stage_slots.items() if not value.snippet]

    def filled_slots_for_stage(self, stage: str) -> Dict[str, str]:
        """Return ``{slot_key: snippet}`` for every filled slot in ``stage``.

        Only slots with a truthy ``snippet`` are included, so every value is a
        non-empty string. A stage missing from ``slots`` yields ``{}``.
        """
        stage_slots = self.slots.get(stage, {})
        return {
            key: value.snippet
            for key, value in stage_slots.items()
            if value.snippet
        }

    def all_stages_coverage(self) -> Dict[str, Dict]:
        """Summarize slot coverage for each stage listed in ``stage_order``.

        Returns:
            Mapping of stage name -> ``{"total", "filled", "empty", "ratio"}``.
            ``total``/``filled``/``empty`` are slot counts. ``ratio`` is
            ``filled / total`` and is always a float. It is ``0.0`` when the
            stage has no slots, including a stage absent from ``slots``.
            Stages present in ``slots`` but not in ``stage_order`` are not
            reported.
        """
        coverage: Dict[str, Dict] = {}
        for stage in self.stage_order:
            stage_slots = self.slots.get(stage, {})
            total = len(stage_slots)
            filled = sum(1 for v in stage_slots.values() if v.snippet)
            coverage[stage] = {
                "total": total,
                "filled": filled,
                "empty": total - filled,
                # 0.0 (not int 0) keeps the ratio's type consistent across stages.
                "ratio": filled / total if total > 0 else 0.0,
            }
        return coverage
# Default memoir stage sequence; default_state() uses it as stage_order and
# starts current_stage at its first entry.
DEFAULT_STAGE_ORDER = [
    "childhood",
    "education",
    "career",
    "family",
    "belief",
]
def default_slots() -> Dict[str, Dict[str, SlotData]]:
    """Build a fresh, all-empty slot table for every default stage.

    Each call creates new ``SlotData`` instances for every slot, so the result
    can be mutated without affecting tables returned by other calls.

    Returns:
        Mapping of stage name -> slot key -> empty ``SlotData``. Stages and
        keys appear in the order listed below.
    """
    slot_keys_by_stage = {
        "childhood": ("place", "people", "daily_life", "emotion", "turning_event"),
        "education": ("school", "city", "motivation", "challenge", "change"),
        "career": ("job", "environment", "decision", "pressure", "growth"),
        "family": ("relationship", "conflict", "support", "responsibility", "change"),
        "belief": ("value", "regret", "pride", "lesson"),
    }
    return {
        stage: {slot_key: SlotData() for slot_key in slot_keys}
        for stage, slot_keys in slot_keys_by_stage.items()
    }
def default_state() -> MemoirStateSchema:
    """Create a new memoir state positioned at the first default stage.

    Returns:
        A ``MemoirStateSchema`` whose ``stage_order`` is a copy of
        ``DEFAULT_STAGE_ORDER``, whose ``current_stage`` is that order's first
        entry, with no covered stages and a freshly built, all-empty slot table.
    """
    # Pass a copy of the module-level list so that mutating
    # ``state.stage_order`` can never alter DEFAULT_STAGE_ORDER, regardless of
    # whether model validation copies list inputs.
    return MemoirStateSchema(
        stage_order=list(DEFAULT_STAGE_ORDER),
        current_stage=DEFAULT_STAGE_ORDER[0],
        covered_stages=[],
        slots=default_slots(),
    )