259 lines
9.5 KiB
Python
259 lines
9.5 KiB
Python
|
|
"""
|
|||
|
|
回忆录后台处理器
|
|||
|
|
|
|||
|
|
负责:
|
|||
|
|
- 分析用户对话内容,提取关键信息
|
|||
|
|
- 更新回忆录状态(slots)
|
|||
|
|
- 生成/更新章节内容
|
|||
|
|
- 创建创意章节标题
|
|||
|
|
"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import uuid
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Dict, List, Optional
|
|||
|
|
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
|
|||
|
|
from agents.state_schema import MemoirStateSchema
|
|||
|
|
from database.database import AsyncSessionLocal
|
|||
|
|
from database.models import Book, Chapter, Segment
|
|||
|
|
from services.llm_service import llm_service
|
|||
|
|
from services.memoir_state_service import (
|
|||
|
|
get_or_create_state,
|
|||
|
|
get_empty_slots,
|
|||
|
|
mark_stage_complete,
|
|||
|
|
switch_stage,
|
|||
|
|
update_slot,
|
|||
|
|
)
|
|||
|
|
from .prompts.memory_prompts import (
|
|||
|
|
get_creative_title_prompt,
|
|||
|
|
get_narrative_prompt,
|
|||
|
|
get_state_extraction_prompt,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Keyword lists used for a cheap, substring-based guess of which memoir stage
# a message belongs to. Stages are tried in the order written here and the
# first stage with any matching keyword wins.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class AnalysisResult:
    """Outcome of analysing a single user message."""

    # Stage the message was attributed to (keyword guess, possibly overridden
    # by the LLM response).
    detected_stage: str
    # Slot name -> snippet text extracted from the message.
    extracted_slots: Dict[str, str]
    # Emotion label; "neutral" unless the LLM response supplies one.
    emotion: str
    # Whether the LLM flagged the message as starting a new chapter.
    is_new_chapter: bool
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ContentAnalyzer:
    """Analyse a user message: detect its memoir stage and extract slot values."""

    def __init__(self) -> None:
        # Treated as optional: a falsy value disables the LLM path and the
        # keyword / first-empty-slot fallbacks are used instead.
        self.llm = llm_service.get_llm()

    @staticmethod
    def _parse_llm_json(raw) -> Optional[dict]:
        """Return the JSON object contained in an LLM reply, or ``None``.

        Models frequently wrap JSON in Markdown fences or add a sentence
        around it, so after trying the whole reply we retry on the span
        between the first ``{`` and the last ``}``. Anything that does not
        decode to a JSON *object* is rejected.
        """
        if not isinstance(raw, str):
            return None
        text = raw.strip()
        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            candidates.append(text[start:end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    @staticmethod
    def _clean_slots(raw_slots) -> Dict[str, str]:
        """Keep only scalar, non-blank slot values and normalise them to ``str``."""
        if not isinstance(raw_slots, dict):
            return {}
        cleaned: Dict[str, str] = {}
        for name, value in raw_slots.items():
            # bool is an int subclass; a bare true/false is not a usable snippet.
            if isinstance(value, bool) or not isinstance(value, (str, int, float)):
                continue
            text = str(value).strip()
            if text:
                cleaned[str(name)] = text
        return cleaned

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage whose keywords occur in the message.

        Falls back to ``fallback_stage`` when nothing matches (including an
        empty or ``None`` message).
        """
        message = (user_message or "").lower()
        for stage, keywords in STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return fallback_stage

    def _fallback_slots(self, state: MemoirStateSchema, stage: str, user_message: str) -> Dict[str, str]:
        """Put the message (first 200 chars) into the first empty slot of ``stage``.

        Returns an empty dict when the message is blank or the stage has no
        empty slot, so a blank message never fills a slot with an empty snippet.
        """
        snippet = (user_message or "").strip()[:200]
        if not snippet:
            return {}
        stage_slots = state.slots.get(stage, {}) or {}
        for key, value in stage_slots.items():
            if not value.snippet:
                return {key: snippet}
        return {}

    async def analyze_message(self, user_message: str, current_state: MemoirStateSchema) -> AnalysisResult:
        """Analyse one message against the current memoir state.

        The stage is first guessed from keywords. When an LLM is available
        its JSON reply may override the stage and supplies slots, emotion and
        the new-chapter flag; if the reply is not a usable JSON object the
        first-empty-slot fallback is used instead.
        """
        detected_stage = self._detect_stage(user_message, current_state.current_stage)
        extracted_slots: Dict[str, str] = {}
        emotion = "neutral"
        is_new_chapter = False

        if not self.llm:
            extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
        else:
            prompt = get_state_extraction_prompt(
                user_message=user_message,
                current_stage=current_state.current_stage,
                stage_slots=current_state.slots.get(detected_stage, {}),
            )
            # llm.invoke is blocking; run it in the default executor so the
            # event loop stays responsive.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, self.llm.invoke, prompt)
            parsed = self._parse_llm_json(response.content)
            if parsed is None:
                extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
            else:
                llm_stage = parsed.get("detected_stage")
                # Ignore null/blank/non-string stages rather than overwriting
                # the keyword guess with garbage.
                if isinstance(llm_stage, str) and llm_stage.strip():
                    detected_stage = llm_stage.strip()
                extracted_slots = self._clean_slots(parsed.get("slots"))
                llm_emotion = parsed.get("emotion")
                if isinstance(llm_emotion, str) and llm_emotion.strip():
                    emotion = llm_emotion.strip()
                flag = parsed.get("is_new_chapter", False)
                if isinstance(flag, str):
                    # bool("false") is True, so interpret strings explicitly.
                    flag = flag.strip().lower() in ("true", "1", "yes")
                is_new_chapter = bool(flag)

        return AnalysisResult(
            detected_stage=detected_stage,
            extracted_slots=extracted_slots,
            emotion=emotion,
            is_new_chapter=is_new_chapter,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MemoirGenerator:
    """Generate chapter titles and narrative text, with plain fallbacks when no LLM output is usable."""

    def __init__(self) -> None:
        # Treated as optional: a falsy value makes both generators use their
        # non-LLM fallbacks.
        self.llm = llm_service.get_llm()

    async def _invoke_llm(self, prompt) -> str:
        """Run the blocking ``llm.invoke`` in the default executor and return stripped text.

        Returns an empty string when the response carries no string content,
        so callers can fall back instead of crashing on ``.strip()``.
        """
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, self.llm.invoke, prompt)
        content = getattr(response, "content", "")
        return content.strip() if isinstance(content, str) else ""

    @staticmethod
    def _default_title(stage: str) -> str:
        """Title used when no LLM title is available."""
        return f"{stage} 回忆"

    @staticmethod
    def _append_content(existing_content: str, new_content: str) -> str:
        """Append new text to existing text, separated by a blank line."""
        if existing_content:
            return f"{existing_content}\n\n{new_content}"
        return new_content

    async def generate_chapter_title(self, stage: str, slots: Dict[str, str], emotion: str) -> str:
        """Return a chapter title for ``stage``.

        Uses the LLM when available; surrounding straight or curly double
        quotes are removed. Falls back to ``"<stage> 回忆"`` when there is no
        LLM or it returns nothing usable.
        """
        if not self.llm:
            return self._default_title(stage)
        prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
        title = (await self._invoke_llm(prompt)).strip('"“”').strip()
        return title or self._default_title(stage)

    async def generate_narrative(self, stage: str, slots: Dict[str, str], new_content: str, existing_content: str) -> str:
        """Return the chapter narrative after incorporating ``new_content``.

        With an LLM, the prompt receives both the existing and the new text.
        Without one, or when the LLM returns blank text, the new content is
        appended to the existing content so existing text is never replaced
        by an empty string.
        """
        if self.llm:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
            )
            narrative = await self._invoke_llm(prompt)
            if narrative:
                return narrative
        return self._append_content(existing_content, new_content)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BackgroundTaskRunner:
    """Debounced background scheduler that turns queued segments into memoir updates."""

    def __init__(self, debounce_seconds: int = 5) -> None:
        self.debounce_seconds = debounce_seconds
        # user_id -> segment ids waiting to be processed.
        self.pending_tasks: Dict[str, List[str]] = {}
        # user_id -> debounce timer task that is still waiting to fire.
        self._scheduled: Dict[str, asyncio.Task] = {}
        # user_id -> lock serialising processing runs for that user.
        self._locks: Dict[str, asyncio.Lock] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    def _lock_for(self, user_id: str) -> asyncio.Lock:
        """Return the per-user processing lock, creating it on first use."""
        lock = self._locks.get(user_id)
        if lock is None:
            lock = self._locks[user_id] = asyncio.Lock()
        return lock

    @staticmethod
    def _merge_source_ids(existing, new_ids) -> list:
        """Union two id sequences, dropping duplicates but keeping first-seen order."""
        return list(dict.fromkeys([*(existing or []), *new_ids]))

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Queue a segment for ``user_id`` and (re)start that user's debounce timer."""
        self.pending_tasks.setdefault(user_id, []).append(segment_id)
        timer = self._scheduled.get(user_id)
        # Only a timer that is still waiting is registered here (see
        # _debounced_process), so this never cancels in-flight processing.
        if timer is not None and not timer.done():
            timer.cancel()
        self._scheduled[user_id] = asyncio.create_task(self._debounced_process(user_id))

    async def _debounced_process(self, user_id: str) -> None:
        """Wait out the debounce window, then process the user's pending segments."""
        try:
            await asyncio.sleep(self.debounce_seconds)
        except asyncio.CancelledError:
            # Superseded by a newer message; the newer timer will do the work.
            return
        # The window has elapsed: unregister before doing any work so that a
        # message arriving now starts a fresh timer instead of cancelling
        # this run after it has already popped its segment ids.
        if self._scheduled.get(user_id) is asyncio.current_task():
            del self._scheduled[user_id]
        # Serialise runs per user so a fresh timer firing during a long run
        # cannot update the same state/chapters concurrently.
        async with self._lock_for(user_id):
            async with AsyncSessionLocal() as db:
                await self.process_pending(user_id, db)

    async def _upsert_chapter(self, user_id: str, stage: str, stage_segments, state, emotion: str, db):
        """Create or update the chapter for ``stage`` from the given segments and return it."""
        combined_text = "\n\n".join(seg.transcript_text for seg in stage_segments)
        source_ids = [seg.id for seg in stage_segments]

        stmt_chapter = select(Chapter).where(
            Chapter.user_id == user_id,
            Chapter.category == stage,
        )
        result_chapter = await db.execute(stmt_chapter)
        chapter = result_chapter.scalar_one_or_none()

        slot_snippets = {
            key: value.snippet
            for key, value in (state.slots.get(stage, {}) or {}).items()
            if value.snippet
        }

        if chapter:
            # Existing chapters keep their title; only the body is regenerated.
            chapter.content = await self.generator.generate_narrative(
                stage, slot_snippets, combined_text, chapter.content or ""
            )
            chapter.is_new = True
            chapter.source_segments = self._merge_source_ids(chapter.source_segments, source_ids)
            return chapter

        title = await self.generator.generate_chapter_title(stage, slot_snippets, emotion)
        narrative = await self.generator.generate_narrative(stage, slot_snippets, combined_text, "")
        chapter = Chapter(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title=title,
            content=narrative,
            order_index=999,
            status="completed",
            category=stage,
            images=[],
            is_new=True,
            source_segments=source_ids,
        )
        db.add(chapter)
        return chapter

    async def _mark_book_updated(self, user_id: str, chapter_id, db) -> None:
        """Flag the user's most recently updated book (creating one if needed) as updated."""
        stmt_book = (
            select(Book)
            .where(Book.user_id == user_id)
            .order_by(Book.updated_at.desc())
            .limit(1)
        )
        result_book = await db.execute(stmt_book)
        # first() rather than scalar_one_or_none(): a user with several books
        # must yield the newest one, not raise MultipleResultsFound.
        book = result_book.scalars().first()
        if not book:
            book = Book(
                id=str(uuid.uuid4()),
                user_id=user_id,
                title="我的回忆录",
                total_pages=0,
                total_words=0,
                cover_image_url=None,
            )
            db.add(book)
        book.has_update = True
        book.last_update_chapter_id = chapter_id

    async def process_pending(self, user_id: str, db) -> None:
        """Process every queued segment for ``user_id`` using session ``db``.

        Each segment is analysed, memoir state and slots are updated, the
        chapter of each detected stage is created or extended, the user's
        book is flagged as updated, and all loaded segments are marked
        processed before committing. Returns without touching the database
        when nothing is queued.
        """
        segment_ids = self.pending_tasks.pop(user_id, [])
        if not segment_ids:
            return

        stmt = select(Segment).where(Segment.id.in_(segment_ids))
        result = await db.execute(stmt)
        # IN (...) gives no ordering guarantee; restore the order in which
        # the segments were queued so the narrative follows the conversation.
        queue_position = {sid: index for index, sid in enumerate(segment_ids)}
        segments = sorted(
            result.scalars().all(),
            key=lambda seg: queue_position.get(seg.id, len(queue_position)),
        )
        if not segments:
            return

        state = await get_or_create_state(user_id, db)
        stage_to_segments: Dict[str, List[Segment]] = {}
        stage_emotions: Dict[str, str] = {}
        for segment in segments:
            # Segments without transcript text carry nothing to analyse; they
            # are still marked processed below.
            if not (segment.transcript_text or "").strip():
                continue

            analysis = await self.analyzer.analyze_message(segment.transcript_text, state)
            detected_stage = analysis.detected_stage
            if detected_stage != state.current_stage:
                state = await switch_stage(user_id, detected_stage, db)

            for slot_name, snippet in analysis.extracted_slots.items():
                state = await update_slot(
                    user_id=user_id,
                    stage=detected_stage,
                    slot_name=slot_name,
                    snippet=snippet,
                    segment_ids=[segment.id],
                    db=db,
                )

            stage_to_segments.setdefault(detected_stage, []).append(segment)
            # Keep the latest analysed emotion per stage for title generation.
            stage_emotions[detected_stage] = analysis.emotion or "neutral"

        chapter = None
        for stage, stage_segments in stage_to_segments.items():
            chapter = await self._upsert_chapter(
                user_id, stage, stage_segments, state, stage_emotions.get(stage, "neutral"), db
            )

        if chapter is not None:
            await db.flush()
            await self._mark_book_updated(user_id, chapter.id, db)

        empty_slots = await get_empty_slots(user_id, db)
        if not empty_slots:
            await mark_stage_complete(user_id, state.current_stage, db)

        for seg in segments:
            seg.processed = True

        await db.commit()
|