96 lines
2.6 KiB
Python
96 lines
2.6 KiB
Python
"""Phase1 批处理 LLM 分块:大量 segment 时拆多次请求并合并 by_id。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.agents.memoir.batch_phase1_prep import (
|
|
BatchPhase1SegmentRow,
|
|
run_batch_phase1_prep_chunked,
|
|
)
|
|
from app.agents.state_schema import MemoirStateSchema
|
|
|
|
|
|
def _state() -> MemoirStateSchema:
    """Return a minimal single-stage memoir state positioned at ``childhood``.

    No stages are covered yet and no slots are filled; the tests in this
    module only need a valid state object to hand to
    ``run_batch_phase1_prep_chunked``.
    """
    stage = "childhood"
    return MemoirStateSchema(
        stage_order=[stage],
        current_stage=stage,
        covered_stages=[],
        slots={},
    )
|
|
|
|
|
|
def test_run_batch_phase1_prep_chunked_splits_95_into_four_calls(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """95 segments with ``chunk_size=24`` go out as four calls and merge back.

    The single-request prep function ``run_batch_phase1_prep`` is
    monkeypatched with a fake that records each chunk's size and returns one
    row per segment, keyed by ``str(segment.id)``.
    """
    chunk_lengths: list[int] = []

    def fake_prep(
        segments: list,
        state: MemoirStateSchema,
        llm: object,
    ) -> dict[str, BatchPhase1SegmentRow]:
        chunk_lengths.append(len(segments))
        return {
            str(s.id): BatchPhase1SegmentRow(
                detected_stage="childhood",
                slots={},
                chapter_category_raw="summary",
            )
            for s in segments
        }

    monkeypatch.setattr(
        "app.agents.memoir.batch_phase1_prep.run_batch_phase1_prep",
        fake_prep,
    )
    segments = [SimpleNamespace(id=f"s{i}", user_input_text="hello") for i in range(95)]
    by_id = run_batch_phase1_prep_chunked(
        segments,
        _state(),
        MagicMock(),
        chunk_size=24,
    )
    assert len(by_id) == 95
    # Every input id must survive the merge, not merely the right row count.
    assert set(by_id) == {str(s.id) for s in segments}
    # Three full chunks of 24 plus a 23-segment remainder (3 * 24 + 23 == 95).
    assert chunk_lengths == [24, 24, 24, 23]
|
|
|
|
|
|
def test_chunked_bisect_on_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failing chunk is bisected and retried; all ids still merge back.

    The fake ``run_batch_phase1_prep`` raises ``ValueError`` whenever it
    receives exactly 4 segments and succeeds for any other chunk size.
    """
    chunk_lengths: list[int] = []

    def fake_prep(
        segments: list,
        state: MemoirStateSchema,
        llm: object,
    ) -> dict[str, BatchPhase1SegmentRow]:
        chunk_lengths.append(len(segments))
        if len(segments) == 4:
            # Stand-in for the whole-chunk request being rejected
            # (e.g. exceeding a length limit).
            raise ValueError("simulate length limit")
        return {
            str(s.id): BatchPhase1SegmentRow(
                detected_stage="childhood",
                slots={},
                chapter_category_raw="summary",
            )
            for s in segments
        }

    monkeypatch.setattr(
        "app.agents.memoir.batch_phase1_prep.run_batch_phase1_prep",
        fake_prep,
    )
    segments = [SimpleNamespace(id=f"b{i}", user_input_text="x") for i in range(4)]
    by_id = run_batch_phase1_prep_chunked(
        segments,
        _state(),
        MagicMock(),
        chunk_size=100,
    )
    assert len(by_id) == 4
    # The retried halves must together cover exactly the original ids.
    assert set(by_id) == {str(s.id) for s in segments}
    # chunk_size=100 puts all 4 segments into a single first request...
    assert chunk_lengths[0] == 4
    # ...which fails, so at least one retry must have used a 2-segment half.
    assert 2 in chunk_lengths
|