Files
life-echo/api/app/agents/state_schema.py
2026-03-19 14:36:40 +08:00

112 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
共享状态 Schema（对话 Agent 与后台 Agent 共用）
"""
from __future__ import annotations
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
class SlotData(BaseModel):
    """One memoir slot: an optional text snippet and its associated segment IDs.

    The slot is considered unfilled while ``snippet`` is falsy (``None`` or
    an empty string).
    """

    # Text captured for this slot; None until something has been recorded.
    snippet: Optional[str] = Field(default=None)
    # Segment IDs linked to this slot (presumably the segments the snippet
    # was drawn from -- confirm with the writer). Each instance gets its own list.
    segment_ids: List[str] = Field(default_factory=list)
class MemoirStateSchema(BaseModel):
    """Memoir state shared by the dialogue agent and the background agents.

    Slots are organised as ``slots[stage][slot_key] -> SlotData``. A slot
    counts as filled when its ``snippet`` is truthy.
    """

    # Ordered stage names; all_stages_coverage() reports stages in this order.
    stage_order: List[str]
    # Stage that empty_slots_for_current_stage() inspects.
    current_stage: str
    # Stages recorded as covered (not read by any method of this class).
    covered_stages: List[str]
    # stage name -> slot key -> slot data.
    slots: Dict[str, Dict[str, SlotData]]

    def empty_slots_for_current_stage(self) -> List[str]:
        """Return the keys of unfilled slots in ``current_stage``.

        Shorthand for ``empty_slots_for_stage(self.current_stage)``.
        """
        return self.empty_slots_for_stage(self.current_stage)

    def empty_slots_for_stage(self, stage: str) -> List[str]:
        """Return the keys of unfilled slots in *stage*, in slot order.

        A slot is unfilled when its snippet is falsy (``None`` or ``""``).
        An unknown stage yields an empty list.
        """
        stage_slots = self.slots.get(stage, {})
        return [key for key, value in stage_slots.items() if not value.snippet]

    def filled_slots_for_stage(self, stage: str) -> Dict[str, str]:
        """Return ``{slot_key: snippet}`` for every filled slot in *stage*.

        An unknown stage yields an empty dict.
        """
        stage_slots = self.slots.get(stage, {})
        return {
            key: value.snippet for key, value in stage_slots.items() if value.snippet
        }

    def all_stages_coverage(self) -> Dict[str, Dict]:
        """Summarise slot coverage for every stage in ``stage_order``.

        Returns:
            ``{stage: {"total": int, "filled": int, "empty": int, "ratio": float}}``.
            Stages that are listed in ``stage_order`` but have no slots report
            zero counts and a ratio of ``0.0``.
        """
        coverage: Dict[str, Dict] = {}
        for stage in self.stage_order:
            stage_slots = self.slots.get(stage, {})
            total = len(stage_slots)
            filled = sum(1 for v in stage_slots.values() if v.snippet)
            coverage[stage] = {
                "total": total,
                "filled": filled,
                "empty": total - filled,
                # Keep the type a float even for slot-less stages (avoids 0 vs 0.0).
                "ratio": filled / total if total else 0.0,
            }
        return coverage
# Default ordering of memoir stages; default_state() starts at the first entry.
DEFAULT_STAGE_ORDER = [
    "childhood",
    "education",
    "career",
    "family",
    "belief",
]
def default_slots() -> Dict[str, Dict[str, SlotData]]:
    """Build a fresh, entirely empty slot template for the five default stages.

    Returns ``{stage: {slot_key: SlotData()}}``. Every call creates new
    ``SlotData`` objects, so templates returned by separate calls share no
    mutable state.
    """
    # Slot keys per stage, in the order they should appear.
    slot_keys_by_stage = {
        "childhood": ("place", "people", "daily_life", "emotion", "turning_event"),
        "education": ("school", "city", "motivation", "challenge", "change"),
        "career": ("job", "environment", "decision", "pressure", "growth"),
        "family": ("relationship", "conflict", "support", "responsibility", "change"),
        "belief": ("value", "regret", "pride", "lesson"),
    }
    return {
        stage_name: {slot_key: SlotData() for slot_key in slot_keys}
        for stage_name, slot_keys in slot_keys_by_stage.items()
    }
def default_state() -> MemoirStateSchema:
    """Return a new, empty memoir state positioned at the first default stage.

    Every container in the returned state is independent of module-level data.
    ``stage_order`` is an explicit copy of ``DEFAULT_STAGE_ORDER``, so mutating
    one state's order can never leak into the shared default, whatever pydantic
    does with list inputs. ``default_slots()`` builds fresh ``SlotData`` objects.
    """
    stage_order = list(DEFAULT_STAGE_ORDER)
    return MemoirStateSchema(
        stage_order=stage_order,
        current_stage=stage_order[0],
        covered_stages=[],
        slots=default_slots(),
    )