- 访谈:新增 interview_state_hints,联动 orchestrator 与提示词 - 回忆录:story_pipeline_sync/state/memory/post_commit 与 Celery 任务调整 - 基建:开发用 celery broker、compose/development 脚本、依赖注入 - eval-web:移除数据集/实验/版本等页面与流式轮询,突出 Playground - 文档与单测同步
187 lines
5.7 KiB
Python
187 lines
5.7 KiB
Python
"""
|
||
Shared state schema (used by both the conversation agent and the background agent)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from typing import Dict, List, Optional
|
||
|
||
from pydantic import BaseModel, Field
|
||
|
||
from app.agents.stage_constants import CHAT_STAGES
|
||
|
||
|
||
class SlotData(BaseModel):
    """Content captured for a single slot."""

    # Text recorded for the slot. A missing or empty snippet is what the
    # state helpers in this module treat as "slot not filled yet".
    snippet: Optional[str] = Field(default=None)
    # Identifiers of the segments linked to this slot — presumably the
    # conversation segments the snippet came from; TODO confirm with callers.
    segment_ids: List[str] = Field(default_factory=list)
|
||
|
||
|
||
class KnownFact(BaseModel):
    """Session-level known fact, stated in the prompt as something not to ask again."""

    # Short name of the fact; may be empty.
    label: str
    # The fact itself.
    value: str
    # Origin of the fact — free-form; meaning not visible in this module.
    source: str = ""
    # Stage the fact belongs to; an empty string makes it apply to every stage
    # when slots are filtered for prompting.
    stage: str = ""
    # Slot this fact already covers, if any. Spelled Optional[...] (not
    # ``str | None``) because pydantic evaluates field annotations at runtime
    # and the PEP 604 form is not evaluable on Python < 3.10.
    slot_name: Optional[str] = None

    def prompt_line(self) -> str:
        """Render the fact as a single prompt line.

        Returns:
            ``"<label> <value>"`` with surrounding whitespace stripped, or
            just the stripped value when the label is empty or made only of
            ``":"`` characters.
        """
        # NOTE(review): the original appended ":" to the label and then
        # stripped ":" from both ends, so no colon separator ever reaches the
        # output. That output is preserved here; confirm whether
        # "<label>: <value>" was actually intended.
        label = self.label.strip(":")
        if not label:
            return self.value.strip()
        return f"{label} {self.value}".strip()
|
||
|
||
|
||
class PersonaThread(BaseModel):
    """Cross-turn persona thread: a stable trait or motivation of the user to keep echoing."""

    # The trait being tracked.
    trait: str
    # Supporting evidence for the trait; empty when none was recorded.
    evidence: str = ""
    # Origin of the thread — free-form; meaning not visible in this module.
    source: str = ""
    # Stage associated with the thread — not read anywhere in this module.
    stage: str = ""

    def prompt_line(self) -> str:
        """Render the thread for the prompt, appending the evidence when there is any."""
        if not self.evidence:
            return self.trait
        return f"{self.trait}(依据:{self.evidence})"
|
||
|
||
|
||
class MemoirStateSchema(BaseModel):
    """Memoir state.

    ``slots`` maps stage name -> slot name -> ``SlotData``. A slot counts as
    filled when its ``snippet`` is truthy (neither ``None`` nor empty).
    """

    # Ordered stage names; drives the iteration order of all_stages_coverage().
    stage_order: List[str]
    # Stage used by the *_for_current_stage helpers.
    current_stage: str
    # Stages recorded as covered — not read by any method of this class.
    covered_stages: List[str]
    slots: Dict[str, Dict[str, SlotData]]
    known_facts: List[KnownFact] = Field(default_factory=list)
    persona_threads: List[PersonaThread] = Field(default_factory=list)
    recent_questions: List[str] = Field(default_factory=list)

    def empty_slots_for_current_stage(self) -> List[str]:
        """Return the names of the unfilled slots of ``current_stage``."""
        return self.empty_slots_for_stage(self.current_stage)

    def prompt_empty_slots_for_stage(self, stage: str) -> List[str]:
        """Return the slots of ``stage`` that may still be asked about in a prompt.

        Unfilled slots already covered by a known fact are excluded. A fact
        covers a slot when it names that slot and either has no stage or
        belongs to ``stage``.
        """
        blocked = {
            fact.slot_name
            for fact in self.known_facts
            if fact.slot_name and (not fact.stage or fact.stage == stage)
        }
        return [key for key in self.empty_slots_for_stage(stage) if key not in blocked]

    def prompt_empty_slots_for_current_stage(self) -> List[str]:
        """Return the promptable unfilled slots of ``current_stage``."""
        return self.prompt_empty_slots_for_stage(self.current_stage)

    def empty_slots_for_stage(self, stage: str) -> List[str]:
        """Return the names of the unfilled slots of ``stage`` (empty list for an unknown stage)."""
        stage_slots = self.slots.get(stage, {})
        return [key for key, value in stage_slots.items() if not value.snippet]

    def filled_slots_for_stage(self, stage: str) -> Dict[str, str]:
        """Return ``{slot name: snippet}`` for the filled slots of ``stage``."""
        stage_slots = self.slots.get(stage, {})
        return {
            key: value.snippet for key, value in stage_slots.items() if value.snippet
        }

    def all_stages_coverage(self) -> Dict[str, Dict]:
        """Summarise slot coverage for every stage in ``stage_order``.

        Returns:
            ``{stage: {"total", "filled", "empty", "ratio"}}``. ``ratio`` is
            ``filled / total``, or ``0`` for a stage without any slots.
        """
        coverage: Dict[str, Dict] = {}
        for stage in self.stage_order:
            stage_slots = self.slots.get(stage, {})
            total = len(stage_slots)
            filled = sum(1 for slot in stage_slots.values() if slot.snippet)
            coverage[stage] = {
                "total": total,
                "filled": filled,
                "empty": total - filled,
                # Guard against division by zero for stages with no slots.
                "ratio": filled / total if total > 0 else 0,
            }
        return coverage

    def prompt_known_fact_lines(self, *, limit: int = 10) -> List[str]:
        """Return prompt lines for the last ``limit`` known facts, skipping blank lines.

        A non-positive ``limit`` yields an empty list.
        """
        # ``seq[-0:]`` is the whole sequence, so limit <= 0 must be handled
        # before slicing.
        if limit <= 0:
            return []
        rendered = (fact.prompt_line().strip() for fact in self.known_facts[-limit:])
        return [line for line in rendered if line]

    def prompt_persona_thread_lines(self, *, limit: int = 6) -> List[str]:
        """Return prompt lines for the last ``limit`` persona threads, skipping blank lines.

        A non-positive ``limit`` yields an empty list.
        """
        # See prompt_known_fact_lines: a zero limit must not select everything.
        if limit <= 0:
            return []
        rendered = (
            thread.prompt_line().strip() for thread in self.persona_threads[-limit:]
        )
        return [line for line in rendered if line]

    def prompt_recent_question_lines(self, *, limit: int = 4) -> List[str]:
        """Return the last ``limit`` recent questions, stripped and de-duplicated.

        Blank entries are dropped. Duplicates keep their first occurrence
        within the selected window, so fewer than ``limit`` lines may be
        returned. A non-positive ``limit`` yields an empty list.
        """
        if limit <= 0:
            return []
        cleaned = (str(item).strip() for item in self.recent_questions[-limit:])
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(text for text in cleaned if text))
|
||
|
||
|
||
# Same order as stage_constants.CHAT_STAGES. The stages are copied into a new
# list so this module holds its own list object rather than the imported
# sequence itself.
DEFAULT_STAGE_ORDER: list[str] = [*CHAT_STAGES]
|
||
|
||
|
||
def default_slots() -> Dict[str, Dict[str, SlotData]]:
    """Build the initial slot table: every slot of every stage, all empty.

    Each call creates new ``SlotData`` instances, so the returned mapping is
    never shared between states.
    """
    # (stage, slot names) pairs; order here is the insertion order of the result.
    slot_names_by_stage = (
        ("childhood", ("place", "people", "daily_life", "emotion", "turning_event")),
        ("education", ("school", "city", "motivation", "challenge", "change")),
        ("career", ("job", "environment", "decision", "pressure", "growth")),
        ("family", ("relationship", "conflict", "support", "responsibility", "change")),
        ("belief", ("value", "regret", "pride", "lesson")),
    )
    return {
        stage: {slot_name: SlotData() for slot_name in slot_names}
        for stage, slot_names in slot_names_by_stage
    }
|
||
|
||
|
||
def default_state() -> MemoirStateSchema:
    """Create the initial memoir state.

    The state starts on the first entry of ``DEFAULT_STAGE_ORDER``, with no
    covered stages and a fresh set of empty slots.
    """
    first_stage = DEFAULT_STAGE_ORDER[0]
    initial_fields = {
        "stage_order": DEFAULT_STAGE_ORDER,
        "current_stage": first_stage,
        "covered_stages": [],
        "slots": default_slots(),
    }
    return MemoirStateSchema(**initial_fields)
|