"""Phase1 批处理 LLM 分块:大量 segment 时拆多次请求并合并 by_id。""" from __future__ import annotations from types import SimpleNamespace from unittest.mock import MagicMock import pytest from app.agents.memoir.batch_phase1_prep import ( BatchPhase1SegmentRow, run_batch_phase1_prep_chunked, ) from app.agents.state_schema import MemoirStateSchema def _state() -> MemoirStateSchema: return MemoirStateSchema( stage_order=["childhood"], current_stage="childhood", covered_stages=[], slots={}, ) def test_run_batch_phase1_prep_chunked_splits_95_into_four_calls( monkeypatch: pytest.MonkeyPatch, ) -> None: chunk_lengths: list[int] = [] def fake_prep( segments: list, state: MemoirStateSchema, llm: object, *, language: str = "zh", ) -> dict[str, BatchPhase1SegmentRow]: chunk_lengths.append(len(segments)) return { str(s.id): BatchPhase1SegmentRow( detected_stage="childhood", slots={}, chapter_category_raw="summary", ) for s in segments } monkeypatch.setattr( "app.agents.memoir.batch_phase1_prep.run_batch_phase1_prep", fake_prep, ) segments = [SimpleNamespace(id=f"s{i}", user_input_text="hello") for i in range(95)] by_id = run_batch_phase1_prep_chunked( segments, _state(), MagicMock(), chunk_size=24, ) assert len(by_id) == 95 assert chunk_lengths == [24, 24, 24, 23] def test_chunked_bisect_on_value_error(monkeypatch: pytest.MonkeyPatch) -> None: """块内失败时二分重试,仍能拼回全量 id。""" chunk_lengths: list[int] = [] def fake_prep( segments: list, state: MemoirStateSchema, llm: object, *, language: str = "zh", ) -> dict[str, BatchPhase1SegmentRow]: chunk_lengths.append(len(segments)) if len(segments) == 4: raise ValueError("simulate length limit") return { str(s.id): BatchPhase1SegmentRow( detected_stage="childhood", slots={}, chapter_category_raw="summary", ) for s in segments } monkeypatch.setattr( "app.agents.memoir.batch_phase1_prep.run_batch_phase1_prep", fake_prep, ) segments = [SimpleNamespace(id=f"b{i}", user_input_text="x") for i in range(4)] by_id = run_batch_phase1_prep_chunked( segments, _state(), MagicMock(), chunk_size=100, ) assert len(by_id) == 4 assert chunk_lengths[0] == 4 assert 2 in chunk_lengths