Files
life-echo/api/agents/memoir_processor.py
penghanyuan 44bd478c1e agent init
2026-01-21 22:31:09 +01:00

259 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回忆录后台处理器
负责:
- 分析用户对话内容,提取关键信息
- 更新回忆录状态slots
- 生成/更新章节内容
- 创建创意章节标题
"""
from __future__ import annotations

import asyncio
import json
import logging
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional

from sqlalchemy import select

from agents.state_schema import MemoirStateSchema
from database.database import AsyncSessionLocal
from database.models import Book, Chapter, Segment
from services.llm_service import llm_service
from services.memoir_state_service import (
    get_or_create_state,
    get_empty_slots,
    mark_stage_complete,
    switch_stage,
    update_slot,
)

from .prompts.memory_prompts import (
    get_creative_title_prompt,
    get_narrative_prompt,
    get_state_extraction_prompt,
)
# Keyword lists used for substring-based stage detection.
# Insertion order matters to callers that return on the first matching stage.
STAGE_KEYWORDS = {
    # early life and hometown
    "childhood": [
        "童年",
        "小时候",
        "出生",
        "家乡",
        "小镇",
    ],
    # schooling
    "education": [
        "上学",
        "学校",
        "老师",
        "同学",
        "教育",
        "大学",
    ],
    # working life
    "career": [
        "工作",
        "职业",
        "事业",
        "公司",
        "同事",
        "创业",
    ],
    # partner, children, parents
    "family": [
        "伴侣",
        "孩子",
        "家庭",
        "家人",
        "结婚",
        "父母",
    ],
    # values and principles
    "belief": [
        "信念",
        "价值观",
        "座右铭",
        "坚持",
        "原则",
    ],
}
@dataclass
class AnalysisResult:
    """Outcome of analysing a single user message."""

    # Stage the message was attributed to.
    detected_stage: str
    # Slot name -> text snippet extracted from the message.
    extracted_slots: Dict[str, str]
    # Emotion label for the message.
    emotion: str
    # Whether the analysis flagged the message as starting a new chapter.
    is_new_chapter: bool
class ContentAnalyzer:
    """Analyse one user message against the current memoir state.

    When an LLM client is available the message is sent through the state
    extraction prompt and the JSON reply is validated; otherwise (or when the
    reply is unusable) a keyword/first-empty-slot heuristic is used.
    """

    def __init__(self) -> None:
        # A falsy client switches every analysis to the heuristic fallback.
        self.llm = llm_service.get_llm()

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage whose keyword occurs in the message.

        Args:
            user_message: Raw message text; ``None`` is treated as empty.
            fallback_stage: Stage returned when no keyword matches.
        """
        message = (user_message or "").lower()
        for stage, keywords in STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return fallback_stage

    def _fallback_slots(self, state: MemoirStateSchema, stage: str, user_message: str) -> Dict[str, str]:
        """Put the message (trimmed, at most 200 chars) into the first empty slot.

        Returns an empty dict when the message is blank or the stage has no
        slot whose ``snippet`` is still empty.
        """
        snippet = (user_message or "").strip()[:200]
        if not snippet:
            # Never fill a slot with an empty snippet.
            return {}
        stage_slots = state.slots.get(stage, {}) or {}
        for key, value in stage_slots.items():
            if not value.snippet:
                return {key: snippet}
        return {}

    @staticmethod
    def _parse_llm_json(content: str) -> Optional[dict]:
        """Parse the LLM reply into a dict, or return ``None`` if unusable.

        A surrounding Markdown code fence (three backticks, optionally followed
        by a language tag) is removed first. Replies that are not valid JSON,
        or whose top-level value is not an object, yield ``None``.
        """
        if not isinstance(content, str):
            return None
        text = content.strip()
        if text.startswith("```"):
            # Drop the opening fence line, then a trailing closing fence.
            _, _, text = text.partition("\n")
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _clean_slots(raw) -> Dict[str, str]:
        """Keep only slots with a non-blank scalar value, coerced to ``str``.

        ``None``, booleans, nested containers and blank strings are dropped so
        they are never written into the memoir state.
        """
        if not isinstance(raw, dict):
            return {}
        cleaned: Dict[str, str] = {}
        for name, value in raw.items():
            if isinstance(value, bool) or not isinstance(value, (str, int, float)):
                continue
            snippet = str(value).strip()
            if isinstance(name, str) and snippet:
                cleaned[name] = snippet
        return cleaned

    async def analyze_message(self, user_message: str, current_state: MemoirStateSchema) -> AnalysisResult:
        """Detect the stage, slots, emotion and new-chapter flag for a message.

        Args:
            user_message: Text of the user's message.
            current_state: Memoir state providing ``current_stage`` and ``slots``.

        Returns:
            An ``AnalysisResult``. Without a usable LLM reply the stage comes
            from keyword detection, slots from ``_fallback_slots``, emotion is
            ``"neutral"`` and ``is_new_chapter`` is ``False``.
        """
        detected_stage = self._detect_stage(user_message, current_state.current_stage)

        parsed: Optional[dict] = None
        if self.llm:
            prompt = get_state_extraction_prompt(
                user_message=user_message,
                current_stage=current_state.current_stage,
                stage_slots=current_state.slots.get(detected_stage, {}),
            )
            # The client call is synchronous; run it in the default executor
            # so the event loop is not blocked.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, self.llm.invoke, prompt)
            parsed = self._parse_llm_json(response.content)

        if parsed is None:
            return AnalysisResult(
                detected_stage=detected_stage,
                extracted_slots=self._fallback_slots(current_state, detected_stage, user_message),
                emotion="neutral",
                is_new_chapter=False,
            )

        # Only let the LLM override defaults with well-formed values.
        llm_stage = parsed.get("detected_stage")
        if isinstance(llm_stage, str) and llm_stage.strip():
            detected_stage = llm_stage.strip()

        llm_emotion = parsed.get("emotion")
        emotion = llm_emotion.strip() if isinstance(llm_emotion, str) and llm_emotion.strip() else "neutral"

        flag = parsed.get("is_new_chapter", False)
        # bool("false") is True, so string flags are compared explicitly.
        is_new_chapter = flag.strip().lower() == "true" if isinstance(flag, str) else bool(flag)

        return AnalysisResult(
            detected_stage=detected_stage,
            extracted_slots=self._clean_slots(parsed.get("slots")),
            emotion=emotion,
            is_new_chapter=is_new_chapter,
        )
class MemoirGenerator:
    """Generate chapter titles and narrative text, with non-LLM fallbacks."""

    def __init__(self) -> None:
        # A falsy client makes both generators use their plain-text fallbacks.
        self.llm = llm_service.get_llm()

    async def _invoke(self, prompt) -> str:
        """Run the synchronous LLM call in the default executor; return stripped text."""
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, self.llm.invoke, prompt)
        return (response.content or "").strip()

    @staticmethod
    def _clean_title(raw: str, stage: str) -> str:
        """Strip whitespace and wrapping quotes; fall back to the default title if blank."""
        title = (raw or "").strip().strip("\"'“”‘’").strip()
        return title or f"{stage} 回忆"

    @staticmethod
    def _append_content(existing_content: str, new_content: str) -> str:
        """Append new text to existing text, separated by a blank line."""
        if existing_content:
            return f"{existing_content}\n\n{new_content}"
        return new_content

    async def generate_chapter_title(self, stage: str, slots: Dict[str, str], emotion: str) -> str:
        """Return a title for the stage's chapter.

        Without an LLM, or when the LLM reply is blank after removing quotes,
        the default ``"<stage> 回忆"`` title is returned.
        """
        if not self.llm:
            return f"{stage} 回忆"
        prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
        return self._clean_title(await self._invoke(prompt), stage)

    async def generate_narrative(self, stage: str, slots: Dict[str, str], new_content: str, existing_content: str) -> str:
        """Return the chapter narrative that incorporates ``new_content``.

        Without an LLM the new text is appended to the existing text. The same
        happens when the LLM reply is blank, so a caller that stores the result
        never replaces existing chapter content with an empty string.
        """
        if not self.llm:
            return self._append_content(existing_content, new_content)
        prompt = get_narrative_prompt(
            stage=stage,
            slots=slots,
            new_content=new_content,
            existing_content=existing_content,
        )
        narrative = await self._invoke(prompt)
        return narrative or self._append_content(existing_content, new_content)
class BackgroundTaskRunner:
    """Debounced, per-user background processing of queued segments.

    Each queued segment id restarts the user's debounce timer. When the timer
    expires, all ids queued for that user are analysed, the memoir state is
    updated and the affected chapters/book are written in one DB session.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, debounce_seconds: int = 5) -> None:
        self.debounce_seconds = debounce_seconds
        # user_id -> segment ids waiting to be processed.
        self.pending_tasks: Dict[str, List[str]] = {}
        # user_id -> task that is still inside its debounce sleep.
        self._scheduled: Dict[str, asyncio.Task] = {}
        # user_id -> lock serialising processing runs for that user.
        self._locks: Dict[str, asyncio.Lock] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Queue a segment id for the user and restart the debounce timer."""
        self.pending_tasks.setdefault(user_id, []).append(segment_id)
        waiting = self._scheduled.get(user_id)
        if waiting is not None:
            # Only tasks still sleeping are kept in _scheduled, so this never
            # aborts a run that has already started processing.
            waiting.cancel()
        self._scheduled[user_id] = asyncio.create_task(self._debounced_process(user_id))

    async def _debounced_process(self, user_id: str) -> None:
        """Wait out the debounce window, then process the user's pending ids."""
        try:
            await asyncio.sleep(self.debounce_seconds)
        except asyncio.CancelledError:
            # Superseded by a newer message; that task will do the work.
            return
        # Detach before processing so queue_message can no longer cancel us.
        if self._scheduled.get(user_id) is asyncio.current_task():
            del self._scheduled[user_id]
        try:
            # Serialise runs per user: a later run may start while this one is
            # still awaiting the LLM, and both would write the same chapters.
            async with self._locks.setdefault(user_id, asyncio.Lock()):
                async with AsyncSessionLocal() as db:
                    await self.process_pending(user_id, db)
        except Exception:
            # Nobody awaits this task, so log here instead of losing the error.
            self._log.exception("memoir background processing failed for user %s", user_id)

    @staticmethod
    def _merge_source_ids(existing, new) -> list:
        """Merge id lists, dropping duplicates and keeping first-seen order."""
        return list(dict.fromkeys([*(existing or []), *new]))

    async def _get_or_create_book(self, user_id: str, db):
        """Return the user's most recently updated book, creating one if none exists."""
        stmt = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()).limit(1)
        result = await db.execute(stmt)
        # first() rather than scalar_one_or_none(): a user may own several books.
        book = result.scalars().first()
        if book is None:
            book = Book(
                id=str(uuid.uuid4()),
                user_id=user_id,
                title="我的回忆录",
                total_pages=0,
                total_words=0,
                cover_image_url=None,
            )
            db.add(book)
        return book

    async def _upsert_chapter(self, user_id: str, stage: str, stage_segments, state, emotion: str, db):
        """Create or update the chapter for ``stage`` from the given segments.

        An existing chapter keeps its title; a new chapter gets a generated one.
        Returns the chapter object.
        """
        combined_text = "\n\n".join(seg.transcript_text for seg in stage_segments)
        source_ids = [seg.id for seg in stage_segments]

        result = await db.execute(
            select(Chapter).where(
                Chapter.user_id == user_id,
                Chapter.category == stage,
            )
        )
        chapter = result.scalar_one_or_none()

        slot_snippets = {
            key: value.snippet
            for key, value in (state.slots.get(stage, {}) or {}).items()
            if value.snippet
        }

        if chapter:
            narrative = await self.generator.generate_narrative(
                stage, slot_snippets, combined_text, chapter.content or ""
            )
            chapter.content = narrative
            chapter.is_new = True
            chapter.source_segments = self._merge_source_ids(chapter.source_segments, source_ids)
            return chapter

        title = await self.generator.generate_chapter_title(stage, slot_snippets, emotion)
        narrative = await self.generator.generate_narrative(stage, slot_snippets, combined_text, "")
        chapter = Chapter(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title=title,
            content=narrative,
            order_index=999,
            status="completed",
            category=stage,
            images=[],
            is_new=True,
            source_segments=source_ids,
        )
        db.add(chapter)
        await db.flush()
        return chapter

    async def process_pending(self, user_id: str, db) -> None:
        """Process every segment id currently queued for ``user_id``.

        The ids are removed from the queue up front. Each segment with text is
        analysed, its slots are written to the memoir state, and segments are
        grouped by detected stage. Each stage's chapter is then created or
        updated and the book is flagged as updated. All loaded segments are
        marked processed and the session is committed.

        Args:
            user_id: Owner of the queued segments.
            db: Async DB session used for all reads and writes.
        """
        segment_ids = self.pending_tasks.pop(user_id, [])
        if not segment_ids:
            return

        result = await db.execute(select(Segment).where(Segment.id.in_(segment_ids)))
        # An IN query guarantees no row order; restore the queue order so the
        # narrative input follows the order the messages arrived in.
        position = {seg_id: idx for idx, seg_id in enumerate(segment_ids)}
        segments = sorted(
            result.scalars().all(),
            key=lambda seg: position.get(seg.id, len(position)),
        )
        if not segments:
            return

        state = await get_or_create_state(user_id, db)
        stage_to_segments: Dict[str, List[Segment]] = {}
        stage_emotions: Dict[str, str] = {}

        for segment in segments:
            # Segments without text carry nothing to analyse or narrate; they
            # are still marked processed below.
            if not (segment.transcript_text or "").strip():
                continue
            analysis = await self.analyzer.analyze_message(segment.transcript_text, state)
            detected_stage = analysis.detected_stage
            if detected_stage != state.current_stage:
                state = await switch_stage(user_id, detected_stage, db)
            for slot_name, snippet in analysis.extracted_slots.items():
                state = await update_slot(
                    user_id=user_id,
                    stage=detected_stage,
                    slot_name=slot_name,
                    snippet=snippet,
                    segment_ids=[segment.id],
                    db=db,
                )
            stage_to_segments.setdefault(detected_stage, []).append(segment)
            # Remember the latest non-neutral emotion for the title prompt.
            if analysis.emotion and analysis.emotion != "neutral":
                stage_emotions[detected_stage] = analysis.emotion

        book = None
        for stage, stage_segments in stage_to_segments.items():
            chapter = await self._upsert_chapter(
                user_id,
                stage,
                stage_segments,
                state,
                stage_emotions.get(stage, "neutral"),
                db,
            )
            if book is None:
                # Looked up once per run instead of once per stage.
                book = await self._get_or_create_book(user_id, db)
            book.has_update = True
            book.last_update_chapter_id = chapter.id

        empty_slots = await get_empty_slots(user_id, db)
        if not empty_slots:
            await mark_stage_complete(user_id, state.current_stage, db)

        for seg in segments:
            seg.processed = True
        await db.commit()