Files
life-echo/api/tests/test_batch_phase1_chunked.py

96 lines
2.6 KiB
Python
Raw Normal View History

"""Phase1 批处理 LLM 分块:大量 segment 时拆多次请求并合并 by_id。"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from app.agents.memoir.batch_phase1_prep import (
BatchPhase1SegmentRow,
run_batch_phase1_prep_chunked,
)
from app.agents.state_schema import MemoirStateSchema
def _state() -> MemoirStateSchema:
    """Return a minimal single-stage memoir state shared by the chunking tests."""
    only_stage = "childhood"
    return MemoirStateSchema(
        current_stage=only_stage,
        stage_order=[only_stage],
        slots={},
        covered_stages=[],
    )
def test_run_batch_phase1_prep_chunked_splits_95_into_four_calls(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """95 segments with chunk_size=24 are sent as four ordered calls and merged by id."""
    seen_chunks: list[list[str]] = []

    def fake_prep(
        segments: list,
        state: MemoirStateSchema,
        llm: object,
    ) -> dict[str, BatchPhase1SegmentRow]:
        # Record the ids each underlying call receives so the test can check
        # both the chunk sizes and that the chunks partition the input.
        seen_chunks.append([str(s.id) for s in segments])
        return {
            str(s.id): BatchPhase1SegmentRow(
                detected_stage="childhood",
                slots={},
                chapter_category_raw="summary",
            )
            for s in segments
        }

    monkeypatch.setattr(
        "app.agents.memoir.batch_phase1_prep.run_batch_phase1_prep",
        fake_prep,
    )
    expected_ids = [f"s{i}" for i in range(95)]
    segments = [
        SimpleNamespace(id=sid, user_input_text="hello") for sid in expected_ids
    ]
    by_id = run_batch_phase1_prep_chunked(
        segments,
        _state(),
        MagicMock(),
        chunk_size=24,
    )
    assert [len(chunk) for chunk in seen_chunks] == [24, 24, 24, 23]
    # The chunks must cover the input in order: no id dropped, duplicated or reordered.
    assert [sid for chunk in seen_chunks for sid in chunk] == expected_ids
    # The merged mapping must hold exactly the input ids, not merely 95 entries.
    assert len(by_id) == 95
    assert set(by_id) == set(expected_ids)
def test_chunked_bisect_on_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """When a chunk fails it is bisected and retried; all ids are still merged back."""
    chunk_lengths: list[int] = []

    def fake_prep(
        segments: list,
        state: MemoirStateSchema,
        llm: object,
    ) -> dict[str, BatchPhase1SegmentRow]:
        chunk_lengths.append(len(segments))
        # Reject any 4-segment request to force the chunked runner to split it.
        if len(segments) == 4:
            raise ValueError("simulate length limit")
        return {
            str(s.id): BatchPhase1SegmentRow(
                detected_stage="childhood",
                slots={},
                chapter_category_raw="summary",
            )
            for s in segments
        }

    monkeypatch.setattr(
        "app.agents.memoir.batch_phase1_prep.run_batch_phase1_prep",
        fake_prep,
    )
    segments = [SimpleNamespace(id=f"b{i}", user_input_text="x") for i in range(4)]
    by_id = run_batch_phase1_prep_chunked(
        segments,
        _state(),
        MagicMock(),
        chunk_size=100,
    )
    assert len(by_id) == 4
    # Every original id must survive the bisect-and-retry path.
    assert set(by_id) == {f"b{i}" for i in range(4)}
    # chunk_size=100 exceeds the input, so the first attempt sends all 4 at once...
    assert chunk_lengths[0] == 4
    # ...and after that failure the runner retries with halves.
    assert 2 in chunk_lengths