183 lines
5.5 KiB
Python
183 lines
5.5 KiB
Python
|
|
"""Canonical dialogue lineage (message-grade provenance) for memory, story, chapter, eval."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Any, Sequence
|
||
|
|
|
||
|
|
from pydantic import BaseModel, Field, field_validator
|
||
|
|
|
||
|
|
# Schema version stamped on lineage objects; it is the default value of
# ``DialogueLineage.schema_version``.
LINEAGE_SCHEMA_VERSION: int = 1
|
||
|
|
|
||
|
|
|
||
|
|
class DialogueTurnRef(BaseModel):
    """A single interview exchange.

    Pairs the id of a user message with the id of the assistant message that
    answered it. ``assistant_message_id`` stays ``None`` when no reply is
    attached to the turn.
    """

    # Id of the user message that opened this turn (required).
    user_message_id: str
    # Id of the assistant reply, or None when the turn has no reply recorded.
    assistant_message_id: str | None = Field(default=None)
|
||
|
|
|
||
|
|
|
||
|
|
class DialogueLineage(BaseModel):
    """Multi-turn provenance linking a derived artifact to dialogue messages.

    ``turns`` is ordered chronologically within the slice of the conversation
    that contributed to the artifact.

    Fields:
        schema_version: Lineage schema version (must be >= 1); defaults to
            ``LINEAGE_SCHEMA_VERSION``.
        conversation_id: Conversation the referenced messages belong to.
        turns: Contributing user/assistant message pairs. Every turn must
            carry a non-blank ``user_message_id``.
        primary_user_message_id: Optional explicit "main" user message. When
            unset, ``primary_user_message_id_from_lineage`` falls back to the
            first turn's user message.
        segment_ids: Ids of the segments that contributed to this lineage.
    """

    schema_version: int = Field(default=LINEAGE_SCHEMA_VERSION, ge=1)
    conversation_id: str
    turns: list[DialogueTurnRef] = Field(default_factory=list)
    primary_user_message_id: str | None = None
    segment_ids: list[str] = Field(default_factory=list)

    @field_validator("turns")
    @classmethod
    def _non_empty_user_ids(cls, v: list[DialogueTurnRef]) -> list[DialogueTurnRef]:
        """Reject any turn whose ``user_message_id`` is empty or whitespace-only.

        Raises:
            ValueError: reported by pydantic as a ``ValidationError`` on the
                ``turns`` field.
        """
        for t in v:
            if not (t.user_message_id or "").strip():
                raise ValueError("turn.user_message_id must be non-empty")
        return v

    def model_dump_json_safe(self) -> dict[str, Any]:
        """Return this lineage as a dict of JSON-compatible values (``mode="json"``)."""
        return self.model_dump(mode="json")

    @classmethod
    def for_single_turn(
        cls,
        *,
        conversation_id: str,
        user_message_id: str,
        assistant_message_id: str | None = None,
        segment_ids: Sequence[str] | None = None,
    ) -> DialogueLineage:
        """Build a lineage that holds exactly one turn.

        ``user_message_id`` is also recorded as ``primary_user_message_id``.

        Args:
            conversation_id: Conversation the turn belongs to.
            user_message_id: Id of the user message; must be non-blank.
            assistant_message_id: Id of the assistant reply, if any. Defaults
                to ``None`` (no reply recorded), matching the field default on
                ``DialogueTurnRef``.
            segment_ids: Optional contributing segment ids; any sequence is
                accepted and stored as a list.

        Raises:
            pydantic.ValidationError: if ``user_message_id`` is blank.
        """
        return cls(
            conversation_id=conversation_id,
            turns=[
                DialogueTurnRef(
                    user_message_id=user_message_id,
                    assistant_message_id=assistant_message_id,
                )
            ],
            primary_user_message_id=user_message_id,
            # None -> []; any other sequence is materialized as a list.
            segment_ids=list(segment_ids or []),
        )
|
||
|
|
|
||
|
|
|
||
|
|
def parse_dialogue_lineage(raw: Any) -> DialogueLineage | None:
    """Coerce *raw* into a ``DialogueLineage`` on a best-effort basis.

    An existing ``DialogueLineage`` is returned unchanged. A ``dict`` is
    validated into a new model. ``None``, any other type, and a dict that
    fails validation all yield ``None``.
    """
    if isinstance(raw, DialogueLineage):
        return raw
    if isinstance(raw, dict):
        try:
            return DialogueLineage.model_validate(raw)
        except Exception:
            # Deliberately lenient: an invalid payload is reported the same
            # way as a missing one, by returning None.
            return None
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def primary_user_message_id_from_lineage(raw: Any) -> str | None:
    """Return the primary user message id recorded in *raw*, if any.

    Prefers a non-empty explicit ``primary_user_message_id``. Otherwise it
    falls back to the first turn's ``user_message_id``. Returns ``None`` when
    *raw* cannot be parsed or carries no turns.
    """
    lineage = parse_dialogue_lineage(raw)
    if lineage is None:
        return None
    explicit = lineage.primary_user_message_id
    if explicit:
        return explicit
    return lineage.turns[0].user_message_id if lineage.turns else None
|
||
|
|
|
||
|
|
|
||
|
|
def merge_dialogue_lineages(
    lineages: Sequence[DialogueLineage | dict | None],
    *,
    conversation_id: str,
    segment_ids_ordered: list[str] | None = None,
) -> DialogueLineage | None:
    """Merge several lineages into one, preserving first-seen order.

    Turns are deduplicated by stripped ``user_message_id``. The first
    occurrence fixes the turn's position. If that first occurrence has no
    ``assistant_message_id``, the first later duplicate that has one fills it
    in. Segment ids are deduplicated in first-seen order: ids from the input
    lineages come first, then any new ids from ``segment_ids_ordered``.

    Args:
        lineages: Lineages as models, raw dicts, or ``None``. Entries that
            ``parse_dialogue_lineage`` cannot parse are skipped.
        conversation_id: Conversation id stamped on the merged lineage. The
            inputs' own ``conversation_id`` values are not consulted.
        segment_ids_ordered: Extra segment ids appended after those collected
            from the inputs; empty ids are ignored.

    Returns:
        The merged lineage, with ``primary_user_message_id`` set to the first
        merged turn's user message. Returns ``None`` when no usable turn was
        found.
    """
    turns_out: list[DialogueTurnRef] = []
    # user_message_id -> index of its turn in turns_out (dedupe + backfill).
    turn_pos: dict[str, int] = {}
    segments_accum: list[str] = []
    seen_segments: set[str] = set()

    def _add_segment(sid: str) -> None:
        """Append *sid* unless it is empty or already recorded."""
        # The set makes dedupe O(1) per id; the list keeps first-seen order.
        if sid and sid not in seen_segments:
            seen_segments.add(sid)
            segments_accum.append(sid)

    for raw in lineages:
        ln = parse_dialogue_lineage(raw)
        if ln is None:
            continue
        for sid in ln.segment_ids:
            _add_segment(sid)
        for t in ln.turns:
            uid = t.user_message_id.strip()
            if not uid:
                continue
            pos = turn_pos.get(uid)
            if pos is None:
                turn_pos[uid] = len(turns_out)
                turns_out.append(
                    DialogueTurnRef(
                        user_message_id=uid,
                        assistant_message_id=t.assistant_message_id,
                    )
                )
            elif not turns_out[pos].assistant_message_id and t.assistant_message_id:
                # The first occurrence keeps its position, but a later
                # duplicate may carry the assistant reply the first lacked.
                # Example: a user-message-only fallback lineage merged ahead
                # of a full one.
                turns_out[pos] = DialogueTurnRef(
                    user_message_id=uid,
                    assistant_message_id=t.assistant_message_id,
                )

    for sid in segment_ids_ordered or ():
        _add_segment(sid)

    if not turns_out:
        return None

    return DialogueLineage(
        conversation_id=conversation_id,
        turns=turns_out,
        primary_user_message_id=turns_out[0].user_message_id,
        segment_ids=segments_accum,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _segment_conversation_id(seg: Any) -> str:
    """Return the conversation id carried by *seg*, or "" if it has none.

    Checks the ``conversation_id`` attribute first, then the
    ``conversation_id`` stored in ``lineage_json`` (dict or DialogueLineage).
    """
    conv = getattr(seg, "conversation_id", None)
    if conv:
        return str(conv)
    lj = getattr(seg, "lineage_json", None)
    if isinstance(lj, dict) and lj.get("conversation_id"):
        return str(lj["conversation_id"])
    if isinstance(lj, DialogueLineage) and lj.conversation_id:
        return lj.conversation_id
    return ""


def aggregate_lineage_from_segments(
    segments: Sequence[Any],
    *,
    conversation_id_fallback: str | None = None,
) -> dict[str, Any] | None:
    """Build one merged, JSON-safe lineage dict from ORM-style segments.

    Each segment is read through ``getattr`` for ``.id``, ``.conversation_id``,
    ``.lineage_json`` and an optional ``.user_message_id``. A segment
    contributes its parsed ``lineage_json``. When that is missing or cannot
    be parsed, a single-turn lineage built from a non-blank
    ``user_message_id`` is used instead. Segment ids are recorded in input
    order.

    Args:
        segments: Segment objects, in the order their lineage should merge.
        conversation_id_fallback: Conversation id to stamp on the result.
            NOTE(review): despite the name, a truthy value takes precedence
            over the segments' own conversation ids. This is the original
            behavior and is kept; confirm it is intended.

    Returns:
        ``model_dump_json_safe()`` of the merged lineage, or ``None`` in three
        cases: ``segments`` is empty; no conversation id can be found (the
        fallback is tried first, then each segment in order); or no turn
        survives merging.
    """
    if not segments:
        return None

    conv0 = conversation_id_fallback or ""
    if not conv0:
        # Use the first segment that carries a conversation id. For the first
        # segment this is the same attribute-then-lineage_json lookup as before.
        conv0 = next(
            (cid for cid in map(_segment_conversation_id, segments) if cid), ""
        )
    if not conv0:
        return None

    lineages: list[DialogueLineage | dict | None] = []
    seg_ids_order: list[str] = []
    for seg in segments:
        sid = str(getattr(seg, "id", "") or "")
        if sid:
            seg_ids_order.append(sid)

        parsed = parse_dialogue_lineage(getattr(seg, "lineage_json", None))
        if parsed is not None:
            lineages.append(parsed)
            continue

        # No usable lineage_json (missing or malformed): fall back to the
        # segment's own user message. Blank ids are skipped because the
        # lineage model rejects them and would abort the whole aggregation.
        um = getattr(seg, "user_message_id", None)
        um_str = str(um).strip() if um else ""
        if not um_str:
            continue
        lineages.append(
            DialogueLineage.for_single_turn(
                conversation_id=str(getattr(seg, "conversation_id", None) or conv0),
                user_message_id=um_str,
                assistant_message_id=None,
                segment_ids=[sid] if sid else None,
            )
        )

    merged = merge_dialogue_lineages(
        lineages, conversation_id=str(conv0), segment_ids_ordered=seg_ids_order
    )
    return None if merged is None else merged.model_dump_json_safe()
|