91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
|
|
"""conversation.lineage_schemas 单元测试。"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from types import SimpleNamespace
|
||
|
|
|
||
|
|
from app.features.conversation.lineage_schemas import (
|
||
|
|
DialogueLineage,
|
||
|
|
aggregate_lineage_from_segments,
|
||
|
|
merge_dialogue_lineages,
|
||
|
|
primary_user_message_id_from_lineage,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def test_for_single_turn_sets_primary() -> None:
    """A lineage built for one turn reports that turn's user message as primary.

    Also checks that the assistant message id lands on the first turn and that
    the segment ids passed in are stored unchanged.
    """
    lineage = DialogueLineage.for_single_turn(
        conversation_id="c1",
        user_message_id="u1",
        assistant_message_id="a1",
        segment_ids=["s1"],
    )

    assert lineage.primary_user_message_id == "u1"
    first_turn = lineage.turns[0]
    assert first_turn.assistant_message_id == "a1"
    assert lineage.segment_ids == ["s1"]
|
||
|
|
|
||
|
|
|
||
|
|
def test_primary_user_message_id_from_lineage_dict() -> None:
    """The primary user message id can be read from a plain lineage dict."""
    only_turn = {"user_message_id": "x", "assistant_message_id": None}
    payload = {
        "schema_version": 1,
        "conversation_id": "c",
        "turns": [only_turn],
    }

    result = primary_user_message_id_from_lineage(payload)

    assert result == "x"
|
||
|
|
|
||
|
|
|
||
|
|
def test_merge_dialogue_lineages_dedupes_user_messages() -> None:
    """Merging collapses turns that share a user message id.

    "u1" appears in two of the three input lineages but only once in the
    merged result; "u2" follows it.
    """
    # (user_message_id, assistant_message_id, segment_id) per input lineage.
    specs = [
        ("u1", "a1", "s1"),
        ("u1", "a2", "s2"),
        ("u2", "a3", "s3"),
    ]
    lineages = [
        DialogueLineage.for_single_turn(
            conversation_id="c1",
            user_message_id=user_id,
            assistant_message_id=assistant_id,
            segment_ids=[segment_id],
        )
        for user_id, assistant_id, segment_id in specs
    ]

    merged = merge_dialogue_lineages(lineages, conversation_id="c1")

    assert merged is not None
    assert [turn.user_message_id for turn in merged.turns] == ["u1", "u2"]
|
||
|
|
|
||
|
|
|
||
|
|
def test_aggregate_lineage_from_segments_orders_and_merges() -> None:
    """Aggregating ordered segments yields their turns in that same order."""

    def make_segment(segment_id: str, user_id: str, assistant_id: str) -> SimpleNamespace:
        # Lightweight stand-in for a segment record carrying the attributes
        # this test populates: id, conversation_id, lineage_json (a
        # JSON-mode dump of a single-turn lineage) and user_message_id.
        lineage = DialogueLineage.for_single_turn(
            conversation_id="conv",
            user_message_id=user_id,
            assistant_message_id=assistant_id,
            segment_ids=[segment_id],
        )
        return SimpleNamespace(
            id=segment_id,
            conversation_id="conv",
            lineage_json=lineage.model_dump(mode="json"),
            user_message_id=None,
        )

    # Built deliberately out of order: s2 comes before s1.
    segments = [
        make_segment("s2", "u2", "a2"),
        make_segment("s1", "u1", "a1"),
    ]

    expected_order = ["s1", "s2"]
    rank = {sid: pos for pos, sid in enumerate(expected_order)}
    # NOTE(review): the test sorts the segments itself before aggregating, so
    # it exercises aggregation of pre-ordered input; it does not check that
    # aggregate_lineage_from_segments sorts on its own.
    segments = sorted(segments, key=lambda seg: rank[str(seg.id)])

    aggregated = aggregate_lineage_from_segments(
        segments, conversation_id_fallback="conv"
    )

    assert aggregated is not None
    assert [turn["user_message_id"] for turn in aggregated["turns"]] == ["u1", "u2"]
|