Files
life-echo/api/app/agents/state_schema.py
Kevin 064ad2161d refactor(eval+memoir):精简内部评测路由与服务,composite/对话摘要与 judge 能力补强
- 访谈:新增 interview_state_hints,联动 orchestrator 与提示词
- 回忆录:story_pipeline_sync/state/memory/post_commit 与 Celery 任务调整
- 基建:开发用 celery broker、compose/development 脚本、依赖注入
- eval-web:移除数据集/实验/版本等页面与流式轮询,突出 Playground
- 文档与单测同步
2026-04-08 21:36:12 +08:00

187 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
共享状态 Schema对话 Agent 与后台 Agent 共用)
"""
from __future__ import annotations
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from app.agents.stage_constants import CHAT_STAGES
class SlotData(BaseModel):
    """Content captured for a single interview slot."""

    # IDs of the segments associated with this slot; every instance gets
    # its own fresh list.
    segment_ids: list[str] = Field(default_factory=list)
    # Text collected for the slot; None (or an empty string) is treated as
    # "not filled yet" by the slot helpers in this module.
    snippet: str | None = None
class KnownFact(BaseModel):
    """Session-level known fact.

    Lets the prompt state explicitly which things must not be asked again.
    """

    # Short name of the fact; rendered first by ``prompt_line``.
    label: str
    # The fact itself; rendered after the label by ``prompt_line``.
    value: str
    # Where the fact came from (not read anywhere in this module).
    source: str = ""
    # Stage the fact belongs to; an empty string means "any stage" for
    # ``MemoirStateSchema.prompt_empty_slots_for_stage``.
    stage: str = ""
    # Slot this fact already answers, if any.
    slot_name: str | None = None

    def prompt_line(self) -> str:
        """Render the fact as one prompt line.

        Returns:
            ``"<label> <value>"`` with surrounding whitespace removed from
            both parts, or just the non-empty part when the other is blank.
        """
        # The previous ``strip("")`` was a no-op: an empty character set
        # strips nothing, so padded labels leaked into the prompt.
        # NOTE(review): if a specific separator character was meant to be
        # stripped from the label as well, add it here.
        label = self.label.strip()
        value = self.value.strip()
        if label and value:
            return f"{label} {value}"
        return label or value
class PersonaThread(BaseModel):
    """Cross-turn persona thread.

    Used to keep echoing the user's stable traits and motivations.
    """

    # The trait itself; always present in the rendered line.
    trait: str
    # Supporting evidence; appended in parentheses when non-empty.
    evidence: str = ""
    # Where the thread came from (not read anywhere in this module).
    source: str = ""
    # Stage associated with the thread (not read anywhere in this module).
    stage: str = ""

    def prompt_line(self) -> str:
        """Render the thread as one prompt line.

        Returns:
            The trait alone, or the trait followed by the evidence wrapped
            in a balanced pair of full-width parentheses.
        """
        if not self.evidence:
            return self.trait
        # The closing parenthesis was missing, which left every evidence
        # line with an unbalanced "(" in the prompt.
        return f"{self.trait}(依据:{self.evidence})"
class MemoirStateSchema(BaseModel):
    """Memoir state.

    Holds the ordered stages, the stage currently in progress, the per-stage
    slots, and the session context (known facts, persona threads, recent
    questions) that the ``prompt_*`` helpers turn into lists of lines.
    """

    # Stage names in the order used by ``all_stages_coverage``.
    stage_order: List[str]
    # Key into ``slots`` for the stage currently in progress.
    current_stage: str
    # Stages recorded as covered (not read by any method of this class).
    covered_stages: List[str]
    # stage name -> slot name -> slot content.
    slots: Dict[str, Dict[str, SlotData]]
    known_facts: List[KnownFact] = Field(default_factory=list)
    persona_threads: List[PersonaThread] = Field(default_factory=list)
    recent_questions: List[str] = Field(default_factory=list)

    def empty_slots_for_current_stage(self) -> List[str]:
        """Return the slot names of the current stage that have no snippet."""
        # Delegates so that the "empty" rule lives in exactly one place.
        return self.empty_slots_for_stage(self.current_stage)

    def prompt_empty_slots_for_stage(self, stage: str) -> List[str]:
        """Return the slots of ``stage`` that may still be asked about.

        Empty slots already answered by a known fact are excluded. A fact
        blocks a slot when it names that slot and either carries no stage
        or carries this very stage.
        """
        blocked = {
            fact.slot_name
            for fact in self.known_facts
            if fact.slot_name and (not fact.stage or fact.stage == stage)
        }
        return [key for key in self.empty_slots_for_stage(stage) if key not in blocked]

    def prompt_empty_slots_for_current_stage(self) -> List[str]:
        """Same as ``prompt_empty_slots_for_stage`` for the current stage."""
        return self.prompt_empty_slots_for_stage(self.current_stage)

    def empty_slots_for_stage(self, stage: str) -> List[str]:
        """Return the slot names of ``stage`` whose snippet is missing or empty.

        An unknown stage yields an empty list.
        """
        stage_slots = self.slots.get(stage, {})
        return [key for key, value in stage_slots.items() if not value.snippet]

    def filled_slots_for_stage(self, stage: str) -> Dict[str, str]:
        """Return ``{slot name: snippet}`` for the filled slots of ``stage``.

        An unknown stage yields an empty dict.
        """
        stage_slots = self.slots.get(stage, {})
        return {
            key: value.snippet for key, value in stage_slots.items() if value.snippet
        }

    def all_stages_coverage(self) -> Dict[str, Dict]:
        """Summarise slot coverage for every stage in ``stage_order``.

        Returns:
            For each stage a dict with ``total``, ``filled``, ``empty`` and
            ``ratio`` (filled / total, or 0 when the stage has no slots).
        """
        coverage: Dict[str, Dict] = {}
        for stage in self.stage_order:
            stage_slots = self.slots.get(stage, {})
            total = len(stage_slots)
            filled = sum(1 for v in stage_slots.values() if v.snippet)
            coverage[stage] = {
                "total": total,
                "filled": filled,
                "empty": total - filled,
                # Guard the division for stages without any slot.
                "ratio": filled / total if total > 0 else 0,
            }
        return coverage

    def prompt_known_fact_lines(self, *, limit: int = 10) -> List[str]:
        """Render the last ``limit`` known facts as non-empty prompt lines.

        A ``limit`` of zero or less yields an empty list.
        """
        # ``[-0:]`` is the whole list and a negative limit would drop the
        # head instead, so non-positive limits are handled explicitly.
        if limit <= 0:
            return []
        return [
            line
            for fact in self.known_facts[-limit:]
            if (line := fact.prompt_line().strip())
        ]

    def prompt_persona_thread_lines(self, *, limit: int = 6) -> List[str]:
        """Render the last ``limit`` persona threads as non-empty prompt lines.

        A ``limit`` of zero or less yields an empty list.
        """
        # Same slicing pitfall as in ``prompt_known_fact_lines``.
        if limit <= 0:
            return []
        return [
            line
            for thread in self.persona_threads[-limit:]
            if (line := thread.prompt_line().strip())
        ]

    def prompt_recent_question_lines(self, *, limit: int = 4) -> List[str]:
        """Return the last ``limit`` recent questions, stripped and de-duplicated.

        Blank entries are dropped and only the first occurrence of a
        repeated question is kept, in original order. A ``limit`` of zero
        or less yields an empty list.
        """
        if limit <= 0:
            return []
        stripped = (str(item).strip() for item in self.recent_questions[-limit:])
        # dict.fromkeys keeps first-seen order while removing duplicates.
        return list(dict.fromkeys(text for text in stripped if text))
# Same order as stage_constants.CHAT_STAGES; unpacked into a brand-new list
# so that this module never shares a mutable alias with the source sequence.
DEFAULT_STAGE_ORDER: list[str] = [*CHAT_STAGES]
def default_slots() -> Dict[str, Dict[str, SlotData]]:
    """Build the initial slot table with every slot empty.

    Returns:
        A mapping of stage name to a mapping of slot name to a newly
        created ``SlotData``. Each call builds new dicts and new
        ``SlotData`` objects.
    """
    slot_names_by_stage = {
        "childhood": ("place", "people", "daily_life", "emotion", "turning_event"),
        "education": ("school", "city", "motivation", "challenge", "change"),
        "career": ("job", "environment", "decision", "pressure", "growth"),
        "family": ("relationship", "conflict", "support", "responsibility", "change"),
        "belief": ("value", "regret", "pride", "lesson"),
    }
    return {
        stage: {slot_name: SlotData() for slot_name in slot_names}
        for stage, slot_names in slot_names_by_stage.items()
    }
def default_state() -> MemoirStateSchema:
    """Create the initial memoir state.

    Returns:
        A ``MemoirStateSchema`` that uses ``DEFAULT_STAGE_ORDER``, starts at
        its first stage, has no covered stages and carries the empty slot
        table from ``default_slots()``.
    """
    first_stage = DEFAULT_STAGE_ORDER[0]
    empty_slots = default_slots()
    return MemoirStateSchema(
        stage_order=DEFAULT_STAGE_ORDER,
        current_stage=first_stage,
        covered_stages=[],
        slots=empty_slots,
    )