82 lines
2.2 KiB
Python
82 lines
2.2 KiB
Python
"""
|
||
共享状态 Schema(对话 Agent 与后台 Agent 共用)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from typing import Dict, List, Optional
|
||
|
||
from pydantic import BaseModel, Field
|
||
|
||
|
||
class SlotData(BaseModel):
    """Slot data structure: the content recorded for one slot of a stage."""

    # Text recorded for the slot; defaults to None. A falsy value (None or "")
    # is what MemoirStateSchema.empty_slots_for_current_stage treats as empty.
    snippet: Optional[str] = Field(default=None)
    # Presumably the IDs of the source segments behind ``snippet`` — TODO confirm.
    # default_factory gives every instance its own new list.
    segment_ids: List[str] = Field(default_factory=list)
|
||
|
||
|
||
class MemoirStateSchema(BaseModel):
    """Memoir state.

    Holds the ordered stage list, the stage currently in focus, the stages
    recorded as covered, and the per-stage slot table.
    """

    # Stage names in order.
    stage_order: List[str]
    # Stage currently in focus; looked up as a key of ``slots`` (not validated).
    current_stage: str
    # Presumably names of stages already covered; this class never updates it.
    covered_stages: List[str]
    # stage name -> (slot key -> SlotData).
    slots: Dict[str, Dict[str, SlotData]]

    def empty_slots_for_current_stage(self) -> List[str]:
        """Return the slot keys of the current stage whose snippet is empty.

        A slot is considered empty when ``snippet`` is falsy (None or "").
        If ``current_stage`` has no entry in ``slots``, an empty list is
        returned. Keys come back in the stage dict's iteration order.
        """
        current = self.slots.get(self.current_stage, {})
        return [slot_key for slot_key, slot in current.items() if not slot.snippet]
|
||
|
||
|
||
# Default ordering of memoir stages. default_state() uses this as the state's
# stage_order and starts at its first entry. Keep these names in sync with the
# top-level keys produced by default_slots().
DEFAULT_STAGE_ORDER = [
    "childhood",
    "education",
    "career",
    "family",
    "belief",
]
|
||
|
||
|
||
def default_slots() -> Dict[str, Dict[str, SlotData]]:
    """Build a fresh, empty slot table for every default stage.

    Returns a new dict mapping each stage name to a dict of slot key ->
    ``SlotData()``. A new ``SlotData`` is constructed for every slot on every
    call, so tables from separate calls never share slot objects. Stage and
    slot order follow the layout below.
    """
    # Stage name -> slot keys, both in insertion order.
    layout = (
        ("childhood", ("place", "people", "daily_life", "emotion", "turning_event")),
        ("education", ("school", "city", "motivation", "challenge", "change")),
        ("career", ("job", "environment", "decision", "pressure", "growth")),
        (
            "family",
            ("relationship", "conflict", "support", "responsibility", "change"),
        ),
        ("belief", ("value", "regret", "pride", "lesson")),
    )
    return {
        stage: {slot_key: SlotData() for slot_key in slot_keys}
        for stage, slot_keys in layout
    }
|
||
|
||
|
||
def default_state() -> MemoirStateSchema:
    """Create the initial memoir state.

    Uses ``DEFAULT_STAGE_ORDER`` as the stage order, starts at its first
    stage, records no covered stages, and attaches a freshly built slot
    table from ``default_slots()``.
    """
    order = DEFAULT_STAGE_ORDER
    first_stage = order[0]
    return MemoirStateSchema(
        stage_order=order,
        current_stage=first_stage,
        covered_stages=[],
        slots=default_slots(),
    )
|