Merge branch 'feat/improve-agent-prompt'
This commit is contained in:
@@ -14,7 +14,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database.database import SessionLocal
|
||||
from database.models import Book, Chapter, Segment, MemoirState
|
||||
from database.models import Book, Chapter, Segment, MemoirState, User
|
||||
from services.llm_service import llm_service
|
||||
from agents.state_schema import MemoirStateSchema, SlotData, default_state
|
||||
from agents.prompts.memory_prompts import (
|
||||
@@ -23,6 +23,7 @@ from agents.prompts.memory_prompts import (
|
||||
get_state_extraction_prompt,
|
||||
STAGE_TO_ORDER,
|
||||
)
|
||||
from agents.prompts.profile_prompts import format_user_profile_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -179,9 +180,21 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
logger.warning(f"未找到段落: {segment_ids}")
|
||||
return {"status": "no_segments"}
|
||||
|
||||
# 获取用户状态
|
||||
# 获取用户状态和资料
|
||||
state = _get_or_create_state_sync(user_id, db)
|
||||
llm = llm_service.get_llm()
|
||||
|
||||
user_obj = db.get(User, user_id)
|
||||
user_profile = ""
|
||||
user_birth_year = None
|
||||
if user_obj:
|
||||
user_birth_year = user_obj.birth_year
|
||||
user_profile = format_user_profile_context(
|
||||
birth_year=user_obj.birth_year,
|
||||
birth_place=user_obj.birth_place,
|
||||
grew_up_place=user_obj.grew_up_place,
|
||||
occupation=user_obj.occupation,
|
||||
)
|
||||
|
||||
# 按阶段分组处理
|
||||
stage_to_segments: Dict[str, List[Segment]] = {}
|
||||
@@ -257,7 +270,9 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
title_prompt = get_creative_title_prompt(
|
||||
stage=stage,
|
||||
emotion="neutral",
|
||||
slots=slot_snippets
|
||||
slots=slot_snippets,
|
||||
user_profile=user_profile,
|
||||
birth_year=user_birth_year,
|
||||
)
|
||||
title_response = llm.invoke(title_prompt)
|
||||
title = title_response.content.strip().strip('"')
|
||||
@@ -267,6 +282,8 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
slots=slot_snippets,
|
||||
new_content=combined_text,
|
||||
existing_content=existing_content,
|
||||
user_profile=user_profile,
|
||||
birth_year=user_birth_year,
|
||||
)
|
||||
narrative_response = llm.invoke(narrative_prompt)
|
||||
new_narrative = narrative_response.content.strip()
|
||||
|
||||
Reference in New Issue
Block a user