""" 共享状态 Schema(对话 Agent 与后台 Agent 共用) """ from __future__ import annotations from typing import Dict, List, Optional from pydantic import BaseModel, Field from app.agents.stage_constants import CHAT_STAGES class SlotData(BaseModel): """Slot 数据结构""" snippet: Optional[str] = None segment_ids: List[str] = Field(default_factory=list) class KnownFact(BaseModel): """会话级已知事实:供 prompt 明确声明“不要再问这些”。""" label: str value: str source: str = "" stage: str = "" slot_name: str | None = None def prompt_line(self) -> str: prefix = f"{self.label}:".strip(":") if prefix: return f"{prefix} {self.value}".strip() return self.value.strip() class PersonaThread(BaseModel): """跨轮人物主线:用于持续呼应用户的稳定特质与动机。""" trait: str evidence: str = "" source: str = "" stage: str = "" def prompt_line(self) -> str: if self.evidence: return f"{self.trait}(依据:{self.evidence})" return self.trait class MemoirStateSchema(BaseModel): """回忆录状态""" stage_order: List[str] current_stage: str covered_stages: List[str] slots: Dict[str, Dict[str, SlotData]] known_facts: List[KnownFact] = Field(default_factory=list) persona_threads: List[PersonaThread] = Field(default_factory=list) recent_questions: List[str] = Field(default_factory=list) def empty_slots_for_current_stage(self) -> List[str]: stage_slots = self.slots.get(self.current_stage, {}) empty_keys: List[str] = [] for key, value in stage_slots.items(): if not value.snippet: empty_keys.append(key) return empty_keys def prompt_empty_slots_for_stage(self, stage: str) -> List[str]: """生成 prompt 时可追问的槽位,排除已被 known_facts 覆盖的方向。""" blocked = { fact.slot_name for fact in self.known_facts if fact.slot_name and (not fact.stage or fact.stage == stage) } return [key for key in self.empty_slots_for_stage(stage) if key not in blocked] def prompt_empty_slots_for_current_stage(self) -> List[str]: return self.prompt_empty_slots_for_stage(self.current_stage) def empty_slots_for_stage(self, stage: str) -> List[str]: """获取指定阶段的空槽位""" stage_slots = self.slots.get(stage, {}) return [key for key, value in stage_slots.items() if not value.snippet] def filled_slots_for_stage(self, stage: str) -> Dict[str, str]: """获取指定阶段已填充的槽位及其内容""" stage_slots = self.slots.get(stage, {}) return { key: value.snippet for key, value in stage_slots.items() if value.snippet } def all_stages_coverage(self) -> Dict[str, Dict]: """获取所有阶段的覆盖情况摘要""" coverage: Dict[str, Dict] = {} for stage in self.stage_order: stage_slots = self.slots.get(stage, {}) total = len(stage_slots) filled = sum(1 for v in stage_slots.values() if v.snippet) coverage[stage] = { "total": total, "filled": filled, "empty": total - filled, "ratio": filled / total if total > 0 else 0, } return coverage def prompt_known_fact_lines(self, *, limit: int = 10) -> List[str]: xs: List[str] = [] for fact in self.known_facts[-limit:]: line = fact.prompt_line().strip() if line: xs.append(line) return xs def prompt_persona_thread_lines(self, *, limit: int = 6) -> List[str]: xs: List[str] = [] for item in self.persona_threads[-limit:]: line = item.prompt_line().strip() if line: xs.append(line) return xs def prompt_recent_question_lines(self, *, limit: int = 4) -> List[str]: out: List[str] = [] seen: set[str] = set() for item in self.recent_questions[-limit:]: s = str(item).strip() if not s or s in seen: continue seen.add(s) out.append(s) return out # 与 stage_constants.CHAT_STAGES 同一顺序;list() 避免与元组共享可变别名 DEFAULT_STAGE_ORDER: list[str] = list(CHAT_STAGES) def default_slots() -> Dict[str, Dict[str, SlotData]]: return { "childhood": { "place": SlotData(), "people": SlotData(), "daily_life": SlotData(), "emotion": SlotData(), "turning_event": SlotData(), }, "education": { "school": SlotData(), "city": SlotData(), "motivation": SlotData(), "challenge": SlotData(), "change": SlotData(), }, "career": { "job": SlotData(), "environment": SlotData(), "decision": SlotData(), "pressure": SlotData(), "growth": SlotData(), }, "family": { "relationship": SlotData(), "conflict": SlotData(), "support": SlotData(), "responsibility": SlotData(), "change": SlotData(), }, "belief": { "value": SlotData(), "regret": SlotData(), "pride": SlotData(), "lesson": SlotData(), }, } def default_state() -> MemoirStateSchema: return MemoirStateSchema( stage_order=DEFAULT_STAGE_ORDER, current_stage=DEFAULT_STAGE_ORDER[0], covered_stages=[], slots=default_slots(), )