Files
life-echo/api/app/agents/memoir/processor.py
2026-03-19 10:38:11 +08:00

213 lines
7.3 KiB
Python

"""
回忆录后台处理器:分析对话、更新状态、生成章节、创意标题
使用 Celery 进行后台任务处理
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Dict, List
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.core.task_tracker import task_tracker
from app.agents.state_schema import MemoirStateSchema
from app.agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
)
logger = get_logger(__name__)

# Keyword table used by ContentAnalyzer._detect_stage for LLM-free stage
# detection. Stages are tried in insertion order and the first stage with a
# keyword contained in the message wins, so key order is significant (e.g. a
# message matching both a childhood and a family keyword is classified as
# "childhood").
STAGE_KEYWORDS = dict(
    childhood=["童年", "小时候", "出生", "家乡", "小镇"],
    education=["上学", "学校", "老师", "同学", "教育", "大学"],
    career=["工作", "职业", "事业", "公司", "同事", "创业"],
    family=["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    belief=["信念", "价值观", "座右铭", "坚持", "原则"],
)
def _get_langchain_llm():
    """Return the ``langchain_llm`` attribute of the configured LLM provider.

    Returns ``None`` when the provider cannot be obtained or exposes no
    ``langchain_llm``. Callers in this module treat a falsy result as "no LLM
    available" and switch to their heuristic fallbacks, so this function never
    raises.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # Stay best-effort (never crash construction), but leave a trace:
        # otherwise a misconfigured provider silently disables every LLM path.
        logger.warning("获取 LLM 提供者失败,将使用非 LLM 回退逻辑: %s", e)
        return None
@dataclass
class AnalysisResult:
    """Outcome of analysing a single user message (see ContentAnalyzer.analyze_message)."""

    # Memoir stage the message was assigned to, e.g. "career".
    detected_stage: str
    # Slot name -> extracted text for that stage.
    extracted_slots: dict[str, str]
    # Emotion label; "neutral" unless an LLM supplied a different one.
    emotion: str
    # True when the analysis decided the message starts a new chapter.
    is_new_chapter: bool
class ContentAnalyzer:
    """Assign a user message to a memoir stage and extract slot values.

    Uses the model returned by ``_get_langchain_llm`` when one is available.
    When the model is missing or its call/answer is unusable, it falls back to
    keyword stage detection plus a "fill the first empty slot" heuristic.
    """

    # Maximum number of characters of the raw message stored by the fallback.
    FALLBACK_SNIPPET_MAX_CHARS = 200

    def __init__(self) -> None:
        self.llm = _get_langchain_llm()

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage in ``STAGE_KEYWORDS`` with a keyword found in
        ``user_message``, or ``fallback_stage`` if no keyword matches."""
        message = user_message.lower()
        for stage, keywords in STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return fallback_stage

    def _fallback_slots(
        self, state: MemoirStateSchema, stage: str, user_message: str
    ) -> Dict[str, str]:
        """Put the message text into the first empty slot of ``stage``.

        Returns ``{slot_name: snippet}`` for the first slot (in
        ``state.slots[stage]`` iteration order) whose ``snippet`` is falsy.
        The snippet is the stripped message truncated to
        ``FALLBACK_SNIPPET_MAX_CHARS``. Returns ``{}`` when the stage has no
        slots, every slot is filled, or the message is blank.
        """
        snippet = user_message.strip()[: self.FALLBACK_SNIPPET_MAX_CHARS]
        if not snippet:
            # Never "fill" a slot with an empty string.
            return {}
        for slot_name, slot in state.slots.get(stage, {}).items():
            if not slot.snippet:
                return {slot_name: snippet}
        return {}

    @staticmethod
    def _parse_llm_json(raw: str) -> object:
        """Decode ``raw`` as JSON, tolerating a surrounding Markdown code fence.

        LLMs may wrap a JSON answer in ```json ... ``` even when asked for bare
        JSON. Without unwrapping, such answers would always be rejected.

        Raises:
            json.JSONDecodeError: if the unwrapped text is not valid JSON.
        """
        text = raw.strip()
        if text.startswith("```"):
            # Drop the opening fence line ("```" or "```json") and the closing fence.
            _, _, text = text.partition("\n")
            text = text.strip()
            if text.endswith("```"):
                text = text[:-3]
        return json.loads(text)

    @staticmethod
    def _as_bool(value: object) -> bool:
        """Interpret a JSON value as a boolean; ``bool("false")`` would be True."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "yes")
        return bool(value)

    async def analyze_message(
        self, user_message: str, current_state: MemoirStateSchema
    ) -> AnalysisResult:
        """Analyse ``user_message`` in the context of ``current_state``.

        The keyword-detected stage is the starting point. When an LLM is
        available, its JSON answer may override the stage, slots, emotion and
        new-chapter flag. Fields that are missing or of the wrong type keep
        their defaults. With no LLM, or when the call or answer fails, the
        slots come from ``_fallback_slots``.
        """
        detected_stage = self._detect_stage(
            user_message, current_state.current_stage
        )
        extracted_slots: Dict[str, str] = {}
        emotion = "neutral"
        is_new_chapter = False

        if not self.llm:
            extracted_slots = self._fallback_slots(
                current_state, detected_stage, user_message
            )
        else:
            try:
                prompt = get_state_extraction_prompt(
                    user_message=user_message,
                    current_stage=current_state.current_stage,
                    stage_slots=current_state.slots.get(detected_stage, {}),
                )
                response = await self.llm.ainvoke(prompt)
                parsed = self._parse_llm_json(response.content)
                if not isinstance(parsed, dict):
                    raise TypeError(
                        f"expected a JSON object, got {type(parsed).__name__}"
                    )
                # Only accept well-typed, non-empty values; a null/"" stage
                # or emotion would otherwise propagate downstream.
                stage = parsed.get("detected_stage")
                if isinstance(stage, str) and stage.strip():
                    detected_stage = stage.strip()
                slots = parsed.get("slots")
                extracted_slots = slots if isinstance(slots, dict) else {}
                mood = parsed.get("emotion")
                if isinstance(mood, str) and mood.strip():
                    emotion = mood.strip()
                is_new_chapter = self._as_bool(parsed.get("is_new_chapter", False))
            except json.JSONDecodeError as e:
                logger.warning("LLM 返回内容不是合法 JSON,使用回退逻辑: %s", e)
                extracted_slots = self._fallback_slots(
                    current_state, detected_stage, user_message
                )
            except Exception as e:
                logger.error("分析消息失败: %s", e)
                extracted_slots = self._fallback_slots(
                    current_state, detected_stage, user_message
                )

        return AnalysisResult(
            detected_stage=detected_stage,
            extracted_slots=extracted_slots,
            emotion=emotion,
            is_new_chapter=is_new_chapter,
        )
class MemoirGenerator:
    """Produce chapter titles and narrative text for the memoir.

    Each public method returns a deterministic fallback in three cases: no
    LLM is configured, the LLM call fails, or the LLM returns empty text.
    """

    # Opening quote -> matching closing quote, unwrapped from around LLM titles.
    _TITLE_QUOTE_PAIRS = {
        '"': '"',
        "'": "'",
        "“": "”",
        "‘": "’",
        "「": "」",
        "『": "』",
    }

    def __init__(self) -> None:
        self.llm = _get_langchain_llm()

    @staticmethod
    def _default_title(stage: str) -> str:
        """Fallback chapter title derived from the stage name."""
        return f"{stage} 回忆"

    @staticmethod
    def _merge_content(existing_content: str, new_content: str) -> str:
        """Fallback narrative: append ``new_content`` after ``existing_content``."""
        if existing_content:
            return f"{existing_content}\n\n{new_content}"
        return new_content

    @classmethod
    def _clean_title(cls, raw: str) -> str:
        """Normalise raw LLM output into a single-line title.

        Keeps the first non-empty line only, since a title is one line. It then
        repeatedly unwraps a balanced surrounding quote pair, but only when the
        quote characters do not also occur inside. That leaves titles such as
        '"a" and "b"' intact.
        """
        lines = [line.strip() for line in raw.splitlines() if line.strip()]
        title = lines[0] if lines else ""
        while len(title) >= 2:
            opening, closing = title[0], title[-1]
            inner = title[1:-1]
            if (
                cls._TITLE_QUOTE_PAIRS.get(opening) != closing
                or opening in inner
                or closing in inner
            ):
                break
            title = inner.strip()
        return title

    async def generate_chapter_title(
        self, stage: str, slots: Dict[str, str], emotion: str
    ) -> str:
        """Return a creative chapter title, or ``"<stage> 回忆"`` as a fallback."""
        if not self.llm:
            return self._default_title(stage)
        try:
            prompt = get_creative_title_prompt(
                stage=stage, emotion=emotion, slots=slots
            )
            response = await self.llm.ainvoke(prompt)
            title = self._clean_title(response.content)
        except Exception as e:
            logger.error("生成标题失败: %s", e)
            return self._default_title(stage)
        # An empty answer would otherwise become an empty chapter title.
        return title or self._default_title(stage)

    async def generate_narrative(
        self,
        stage: str,
        slots: Dict[str, str],
        new_content: str,
        existing_content: str,
    ) -> str:
        """Return narrative text combining ``existing_content`` and ``new_content``.

        Uses the LLM when available. Otherwise, and on failure or an empty
        answer, appends ``new_content`` to ``existing_content`` with a blank
        line between them.
        """
        if not self.llm:
            return self._merge_content(existing_content, new_content)
        try:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
            )
            response = await self.llm.ainvoke(prompt)
            narrative = response.content.strip()
        except Exception as e:
            logger.error("生成叙事失败: %s", e)
            return self._merge_content(existing_content, new_content)
        if not narrative:
            # The result appears to be the full chapter text (the fallbacks
            # return existing + new), so an empty answer would drop
            # existing_content. Fall back instead.
            return self._merge_content(existing_content, new_content)
        return narrative
class BackgroundTaskRunner:
    """Debounce memoir segments per user and submit them to Celery in batches.

    Each ``queue_message`` call appends a segment id to the user's pending
    list and (re)starts a ``debounce_seconds`` timer. When a timer expires
    without being superseded, all pending ids for that user are submitted as a
    single ``process_memoir_segments`` task. ``flush_pending`` submits
    immediately.

    NOTE(review): state lives in in-memory dicts, so debouncing only applies
    within one process/event loop.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        self.debounce_seconds = debounce_seconds
        # user_id -> segment ids waiting to be submitted.
        self._pending: Dict[str, List[str]] = {}
        # user_id -> asyncio task currently waiting out the debounce window.
        # A task removes itself once its window elapses.
        self._timers: Dict[str, "asyncio.Task[None]"] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Enqueue one Celery task for ``segment_ids`` and register it with the tracker.

        Returns the Celery task id, or ``None`` if the task could not be
        enqueued. A failure to register with ``task_tracker`` is logged. It
        does not hide the id of a task that was already enqueued.
        """
        try:
            # Imported lazily, presumably to avoid an import cycle with the
            # Celery task module -- confirm before hoisting.
            from app.tasks.memoir_tasks import process_memoir_segments

            # NOTE(review): Celery's ``delay`` talks to the broker synchronously
            # and may briefly block the event loop.
            task_id = process_memoir_segments.delay(user_id, segment_ids).id
        except Exception as e:
            logger.error("提交 Celery 任务失败: %s", e)
            return None
        try:
            await task_tracker.add_task(user_id, task_id, "memoir")
        except Exception as e:
            logger.error(
                "登记任务追踪失败: user_id=%s, task_id=%s, error=%s",
                user_id,
                task_id,
                e,
            )
        logger.info(
            "已提交 Celery 任务: user_id=%s, task_id=%s, segments=%s",
            user_id,
            task_id,
            len(segment_ids),
        )
        return task_id

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Add ``segment_id`` to the user's batch and restart the debounce timer."""
        import asyncio

        self._pending.setdefault(user_id, []).append(segment_id)
        previous = self._timers.get(user_id)
        if previous is not None:
            previous.cancel()

        async def delayed_submit() -> None:
            try:
                await asyncio.sleep(self.debounce_seconds)
            except asyncio.CancelledError:
                # Superseded by a newer message or by flush_pending; the
                # pending ids stay queued for whoever cancelled us.
                return
            # Unregister before submitting. This keeps _timers from growing with
            # finished tasks, and stops a later queue_message/flush_pending from
            # cancelling an in-flight submission.
            if self._timers.get(user_id) is asyncio.current_task():
                del self._timers[user_id]
            segment_ids = self._pending.pop(user_id, [])
            if not segment_ids:
                return
            try:
                await self._submit_task(user_id, segment_ids)
            except Exception as e:
                logger.error("延迟提交任务失败: %s", e)

        self._timers[user_id] = asyncio.create_task(delayed_submit())

    async def flush_pending(self, user_id: str) -> str | None:
        """Cancel the user's debounce timer and submit pending ids immediately.

        Returns the Celery task id, or ``None`` if nothing was pending or the
        submission failed.
        """
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()
        segment_ids = self._pending.pop(user_id, [])
        if not segment_ids:
            return None
        return await self._submit_task(user_id, segment_ids)