Files
life-echo/api/app/agents/memoir/orchestrator.py
2026-03-19 14:36:40 +08:00

124 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MemoirOrchestrator按 segment 编排流水线,调用各 Specialist Agent。
负责:遍历 segments、按 category 聚合、调用 Specialist、更新 state
持久化与章节生成由 process_category 回调完成。
"""
from __future__ import annotations
from typing import Any, Callable, Dict, List, Set, Tuple
from app.core.logging import get_logger
from app.features.conversation.models import Segment
from app.agents.state_schema import MemoirStateSchema
from app.agents.memoir.extraction_agent import ExtractionAgent, ExtractionResult
from app.agents.memoir.classification_agent import (
ClassificationAgent,
_detect_stage as detect_stage_from_keywords,
)
# Module-level logger, obtained from the project's get_logger helper and named after this module.
logger = get_logger(__name__)
class MemoirOrchestrator:
    """Orchestrate memoir generation over a batch of conversation segments.

    Pipeline: iterate segments -> ExtractionAgent -> ClassificationAgent ->
    group segments by chapter category -> call ``process_category`` once per
    category. Narrative generation and persistence are delegated to that
    callback; this class only sequences the steps and keeps ``state`` current.
    """

    def __init__(self) -> None:
        """Create the specialist agents used by :meth:`run`."""
        self.extraction_agent = ExtractionAgent()
        self.classification_agent = ClassificationAgent()

    def run(
        self,
        *,
        segments: List[Segment],
        llm: Any,
        user_profile: str = "",
        user_birth_year: Any = None,
        get_or_create_state: Callable[[], MemoirStateSchema],
        update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema],
        acquire_lock: Callable[[str], bool],
        release_lock: Callable[[str], None],
        process_category: Callable[
            [
                str,
                List[Segment],
                MemoirStateSchema,
                str,
                Any,
                Any,
            ],
            Tuple[Any, bool],
        ],
        raise_retry: Callable[[], None],
    ) -> Tuple[Set[str], int]:
        """Run the memoir pipeline over ``segments``.

        Args:
            segments: Conversation segments to process; each one's
                ``transcript_text`` (``None`` is treated as ``""``) and ``id``
                are read.
            llm: Opaque LLM handle, forwarded unchanged to the agents and to
                ``process_category``.
            user_profile: Forwarded unchanged to ``process_category``.
            user_birth_year: Forwarded unchanged to ``process_category``.
            get_or_create_state: Called once, up front, to obtain the state.
            update_slot: Called as ``(stage, slot_name, snippet, [segment_id])``
                for every extracted slot; its return value replaces the
                current state.
            acquire_lock: Called with a category; a falsy result means the
                lock is held elsewhere.
            release_lock: Called with a category after ``process_category``
                finishes (successfully or not) for a lock this run acquired.
            process_category: Called as ``(category, segments, state,
                user_profile, user_birth_year, llm)`` and returns
                ``(chapter, has_images_to_generate)``.
            raise_retry: Called on lock contention; expected to raise (e.g. a
                Celery retry). If it returns instead, the contended category
                is skipped for this run.

        Returns:
            ``(chapters_to_enqueue, processed_count)`` where
            ``chapters_to_enqueue`` holds ``chapter.id`` for every truthy
            chapter whose ``has_images`` flag is truthy, and
            ``processed_count`` is ``len(segments)`` (segments skipped by the
            classifier are still counted).

        Raises:
            Whatever ``raise_retry`` or any of the other callbacks/agents raise.
        """
        state = get_or_create_state()
        chapters_to_enqueue: Set[str] = set()
        category_to_segments: Dict[str, List[Segment]] = {}

        # 1) Per segment: extract slots, update state, classify, then group
        #    the segment under its chapter category.
        for segment in segments:
            text = segment.transcript_text or ""

            # Keyword-based stage pre-detection, used only to pick which
            # stage's slots are handed to the extraction agent.
            initial_stage = detect_stage_from_keywords(
                text, state.current_stage or "childhood"
            )
            stage_slots_raw = state.slots.get(initial_stage, {}) or {}

            result: ExtractionResult = self.extraction_agent.extract(
                user_message=text,
                current_stage=state.current_stage or "childhood",
                stage_slots=stage_slots_raw,
                llm=llm,
            )
            detected_stage = result.detected_stage

            # Each update returns the refreshed state; keep the latest one.
            for slot_name, snippet in result.slots.items():
                state = update_slot(detected_stage, slot_name, snippet, [segment.id])

            chapter_category = self.classification_agent.classify(
                text=text,
                fallback_stage=detected_stage,
                llm=llm,
            )
            if chapter_category is None:
                # The classifier returned no category: leave the segment out.
                logger.info("段落无回忆录价值,跳过: segment_id=%s", segment.id)
                continue
            category_to_segments.setdefault(chapter_category, []).append(segment)

        # 2) Per category: take the lock, delegate to process_category, and
        #    always release the lock afterwards.
        for chapter_category, category_segments in category_to_segments.items():
            if not acquire_lock(chapter_category):
                logger.warning(
                    "章节锁竞争: category=%s, 延迟重试",
                    chapter_category,
                )
                raise_retry()
                # raise_retry normally raises. If it returns (for example a
                # Celery retry issued with throw=False), we do NOT hold the
                # lock: never process the category, and never release a lock
                # owned by another worker.
                continue
            try:
                chapter, has_images = process_category(
                    chapter_category,
                    category_segments,
                    state,
                    user_profile,
                    user_birth_year,
                    llm,
                )
                if chapter and has_images:
                    chapters_to_enqueue.add(chapter.id)
            finally:
                release_lock(chapter_category)

        return chapters_to_enqueue, len(segments)