Merge branch 'feat/improve-agent-prompt'
This commit is contained in:
24
.github/workflows/docker-build-deploy.yml
vendored
24
.github/workflows/docker-build-deploy.yml
vendored
@@ -7,7 +7,13 @@ on:
|
||||
paths:
|
||||
- 'api/**'
|
||||
- '.github/workflows/**'
|
||||
workflow_dispatch: # 允许手动触发
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: '部署分支'
|
||||
required: false
|
||||
type: string
|
||||
default: ''
|
||||
|
||||
env:
|
||||
IMAGE_NAME: lifecho-api
|
||||
@@ -24,6 +30,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.branch || github.ref }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -77,6 +85,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.branch || github.ref }}
|
||||
|
||||
- name: Set up SSH
|
||||
uses: webfactory/ssh-agent@v0.9.0
|
||||
@@ -91,11 +101,12 @@ jobs:
|
||||
- name: Determine image tag
|
||||
id: image_tag
|
||||
run: |
|
||||
if [ "${{ github.ref_name }}" == "main" ] || [ "${{ github.ref_name }}" == "master" ]; then
|
||||
DEPLOY_BRANCH="${{ github.event.inputs.branch || github.ref_name }}"
|
||||
echo "deploy_branch=$DEPLOY_BRANCH" >> $GITHUB_OUTPUT
|
||||
if [ "$DEPLOY_BRANCH" == "main" ] || [ "$DEPLOY_BRANCH" == "master" ]; then
|
||||
echo "tag=latest" >> $GITHUB_OUTPUT
|
||||
else
|
||||
# 将分支名中的斜杠替换为破折号,以符合 Docker 标签规范
|
||||
BRANCH_TAG=$(echo "${{ github.ref_name }}" | sed 's/\//-/g')
|
||||
BRANCH_TAG=$(echo "$DEPLOY_BRANCH" | sed 's/\//-/g')
|
||||
echo "tag=$BRANCH_TAG" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
@@ -214,6 +225,11 @@ jobs:
|
||||
"docker exec -i life-echo-postgres psql -U $DB_USER -d $DB_NAME" \
|
||||
< api/migrations/add_chapter_is_active.sql
|
||||
|
||||
echo "添加用户基础资料字段..."
|
||||
ssh -p $SSH_PORT $SSH_USER@$SSH_HOST \
|
||||
"docker exec -i life-echo-postgres psql -U $DB_USER -d $DB_NAME" \
|
||||
< api/migrations/add_user_profile_fields.sql
|
||||
|
||||
echo "数据库迁移完成"
|
||||
|
||||
- name: Verify deployment
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
||||
支持异步调用和 Redis 会话存储
|
||||
支持用户基础资料收集和时代背景融入
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
@@ -11,6 +13,13 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from services.llm_service import llm_service
|
||||
from services.redis_service import redis_service
|
||||
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
|
||||
from .prompts.profile_prompts import (
|
||||
get_profile_greeting_prompt,
|
||||
get_profile_extraction_prompt,
|
||||
get_profile_followup_prompt,
|
||||
format_user_profile_context,
|
||||
get_missing_profile_fields,
|
||||
)
|
||||
from .state_schema import MemoirStateSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -102,6 +111,87 @@ class ConversationAgent:
|
||||
logger.error(f"生成回应失败: {e}")
|
||||
return f"抱歉,生成回应时出现错误: {str(e)}"
|
||||
|
||||
async def generate_profile_greeting(
|
||||
self,
|
||||
conversation_id: str,
|
||||
missing_fields: List[str],
|
||||
nickname: str = "",
|
||||
) -> List[str]:
|
||||
"""生成资料收集的开场白(首次对话时使用)"""
|
||||
if not self.llm:
|
||||
return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
|
||||
|
||||
try:
|
||||
prompt = get_profile_greeting_prompt(missing_fields, nickname)
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:3] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error(f"生成资料收集开场白失败: {e}")
|
||||
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
|
||||
|
||||
async def extract_profile_from_message(self, user_message: str, missing_fields: List[str]) -> Dict[str, Any]:
|
||||
"""从用户消息中提取基础资料信息"""
|
||||
if not self.llm or not missing_fields:
|
||||
return {}
|
||||
|
||||
try:
|
||||
prompt = get_profile_extraction_prompt(user_message, missing_fields)
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
content = response.content.strip()
|
||||
parsed = json.loads(content)
|
||||
result = {}
|
||||
if "birth_year" in parsed and isinstance(parsed["birth_year"], int):
|
||||
result["birth_year"] = parsed["birth_year"]
|
||||
if "birth_place" in parsed and parsed["birth_place"]:
|
||||
result["birth_place"] = str(parsed["birth_place"])
|
||||
if "grew_up_place" in parsed and parsed["grew_up_place"]:
|
||||
result["grew_up_place"] = str(parsed["grew_up_place"])
|
||||
if "occupation" in parsed and parsed["occupation"]:
|
||||
result["occupation"] = str(parsed["occupation"])
|
||||
return result
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.error(f"提取资料信息失败: {e}")
|
||||
return {}
|
||||
|
||||
async def generate_profile_followup(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
missing_fields: List[str],
|
||||
filled_fields: Dict[str, str],
|
||||
nickname: str = "",
|
||||
) -> List[str]:
|
||||
"""在资料收集过程中生成跟进回复"""
|
||||
if not self.llm:
|
||||
return ["谢谢!还能告诉我更多吗?"]
|
||||
|
||||
try:
|
||||
prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname)
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
return messages[:3] if messages else [response_text]
|
||||
except Exception as e:
|
||||
logger.error(f"生成资料跟进回复失败: {e}")
|
||||
return ["谢谢分享!能再告诉我一些吗?"]
|
||||
|
||||
def _detect_user_stage(self, user_message: str) -> str:
|
||||
"""
|
||||
通过关键词检测用户当前正在谈论的人生阶段。
|
||||
@@ -126,7 +216,8 @@ class ConversationAgent:
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
memoir_state: MemoirStateSchema
|
||||
memoir_state: MemoirStateSchema,
|
||||
user_profile_context: str = "",
|
||||
) -> List[str]:
|
||||
"""
|
||||
基于共享状态异步生成引导式回复
|
||||
@@ -135,6 +226,7 @@ class ConversationAgent:
|
||||
conversation_id: 对话 ID
|
||||
user_message: 用户消息
|
||||
memoir_state: 共享状态
|
||||
user_profile_context: 用户基础资料上下文
|
||||
|
||||
Returns:
|
||||
Agent 回应文本列表(支持多条消息)
|
||||
@@ -150,18 +242,11 @@ class ConversationAgent:
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
# 检测用户当前正在谈论的阶段
|
||||
detected_user_stage = self._detect_user_stage(user_message)
|
||||
|
||||
# 从 Redis 获取对话历史,用于计算对话轮数
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
conversation_turn = len(history_messages) // 2 # 每轮包括一个用户消息和一个AI回复
|
||||
|
||||
# 计算同一话题的轮数(简单估算:基于已填充槽位的变化)
|
||||
# 如果槽位数量没有增加,说明还在同一话题深入
|
||||
conversation_turn = len(history_messages) // 2
|
||||
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
|
||||
|
||||
# 获取所有阶段的覆盖情况
|
||||
all_stages_coverage = memoir_state.all_stages_coverage()
|
||||
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
@@ -173,24 +258,19 @@ class ConversationAgent:
|
||||
same_topic_turns=same_topic_turns,
|
||||
all_stages_coverage=all_stages_coverage,
|
||||
detected_user_stage=detected_user_stage,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
|
||||
history_string = self._format_history_string(history_messages)
|
||||
|
||||
# 构建完整 prompt
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
|
||||
# 异步调用 LLM
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 保存对话到 Redis
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
|
||||
# 支持多条消息,用 [SPLIT] 分隔
|
||||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||||
# 最多返回 3 条
|
||||
return messages[:3] if messages else [response_text]
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -19,6 +19,14 @@ from .memory_prompts import (
|
||||
CHAPTER_ORDER,
|
||||
STAGE_TO_ORDER,
|
||||
)
|
||||
from .profile_prompts import (
|
||||
get_profile_greeting_prompt,
|
||||
get_profile_extraction_prompt,
|
||||
get_profile_followup_prompt,
|
||||
format_user_profile_context,
|
||||
get_missing_profile_fields,
|
||||
PROFILE_FIELD_NAMES,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConversationStage",
|
||||
@@ -35,5 +43,11 @@ __all__ = [
|
||||
"CHAPTER_CATEGORIES",
|
||||
"CHAPTER_ORDER",
|
||||
"STAGE_TO_ORDER",
|
||||
"get_profile_greeting_prompt",
|
||||
"get_profile_extraction_prompt",
|
||||
"get_profile_followup_prompt",
|
||||
"format_user_profile_context",
|
||||
"get_missing_profile_fields",
|
||||
"PROFILE_FIELD_NAMES",
|
||||
]
|
||||
|
||||
|
||||
@@ -173,6 +173,73 @@ RESPONSE_STYLES = [
|
||||
]
|
||||
|
||||
|
||||
def _build_era_context(current_stage: str, user_profile_context: str) -> str:
|
||||
"""
|
||||
根据用户的人生阶段和出生年份,生成对应时代的历史/政治/文化背景提示。
|
||||
让 agent 在对话中自然融入时代感。
|
||||
"""
|
||||
if not user_profile_context:
|
||||
return ""
|
||||
|
||||
birth_year = None
|
||||
birth_place = ""
|
||||
for line in user_profile_context.split("\n"):
|
||||
if "出生年份" in line:
|
||||
try:
|
||||
birth_year = int(line.split(":")[1].strip().replace("年", ""))
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
if "出生地" in line or "成长地" in line:
|
||||
birth_place = line.split(":")[1].strip() if ":" in line else ""
|
||||
|
||||
if not birth_year:
|
||||
return ""
|
||||
|
||||
stage_era_map = {
|
||||
"childhood": (0, 12),
|
||||
"education": (6, 22),
|
||||
"career": (18, 50),
|
||||
"family": (20, 50),
|
||||
"belief": (30, 60),
|
||||
}
|
||||
|
||||
age_range = stage_era_map.get(current_stage, (0, 30))
|
||||
era_start = birth_year + age_range[0]
|
||||
era_end = birth_year + age_range[1]
|
||||
|
||||
era_events = []
|
||||
|
||||
decade_events = {
|
||||
1950: "新中国成立初期、土地改革、抗美援朝",
|
||||
1960: "大跃进、三年自然灾害、中苏关系变化",
|
||||
1970: "文化大革命、知青上山下乡、中美建交",
|
||||
1980: "改革开放、恢复高考、个体经济兴起、电视普及",
|
||||
1990: "社会主义市场经济、下海潮、香港回归、互联网初期",
|
||||
2000: "加入WTO、房地产兴起、手机普及、北京奥运",
|
||||
2010: "移动互联网爆发、微信时代、共享经济、双创浪潮",
|
||||
2020: "新冠疫情、直播经济、人工智能崛起",
|
||||
}
|
||||
|
||||
for decade, events in decade_events.items():
|
||||
if era_start <= decade + 9 and era_end >= decade:
|
||||
era_events.append(f"{decade}年代:{events}")
|
||||
|
||||
if not era_events:
|
||||
return ""
|
||||
|
||||
place_hint = f"(用户来自{birth_place})" if birth_place else ""
|
||||
|
||||
return f"""
|
||||
## 时代背景参考{place_hint}
|
||||
用户在这个人生阶段大约经历了 {era_start}-{era_end} 年({age_range[0]}-{age_range[1]} 岁):
|
||||
{";".join(era_events)}
|
||||
|
||||
你可以在对话中自然地提及这些时代元素来丰富提问,例如:
|
||||
- "那个年代好像正好赶上xxx,你们那边是什么情况?"
|
||||
- "听说那时候xxx特别流行,你有印象吗?"
|
||||
- 不要生硬地列举历史事件,而是像聊天一样自然带入"""
|
||||
|
||||
|
||||
def get_guided_conversation_prompt(
|
||||
current_stage: str,
|
||||
empty_slots: List[str],
|
||||
@@ -182,6 +249,7 @@ def get_guided_conversation_prompt(
|
||||
same_topic_turns: int = 0,
|
||||
all_stages_coverage: Dict[str, Dict] = None,
|
||||
detected_user_stage: str = "",
|
||||
user_profile_context: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
生成状态感知的对话提示词
|
||||
@@ -195,6 +263,7 @@ def get_guided_conversation_prompt(
|
||||
same_topic_turns: 同一话题的轮数
|
||||
all_stages_coverage: 所有阶段的覆盖情况 {stage: {total, filled, empty, ratio}}
|
||||
detected_user_stage: 检测到用户正在谈论的阶段(可能和 current_stage 不同)
|
||||
user_profile_context: 用户基础资料上下文
|
||||
"""
|
||||
stage_name_map = {
|
||||
"childhood": "童年时光",
|
||||
@@ -286,8 +355,16 @@ def get_guided_conversation_prompt(
|
||||
else:
|
||||
topic_desc = f"你们聊到了「{current_stage_name}」这个话题"
|
||||
|
||||
prompt = f"""你是「岁月知己」,用户的老朋友,正在和他/她聊人生故事。{topic_desc}。
|
||||
# --- 用户资料和时代背景 ---
|
||||
profile_section = ""
|
||||
if user_profile_context:
|
||||
profile_section = f"\n## 用户基本信息\n{user_profile_context}\n"
|
||||
|
||||
active_stage = detected_user_stage if user_jumped and detected_user_stage else current_stage
|
||||
era_context = _build_era_context(active_stage, user_profile_context)
|
||||
|
||||
prompt = f"""你是「岁月知己」,用户的老朋友,正在和他/她聊人生故事。{topic_desc}。
|
||||
{profile_section}
|
||||
## 已经聊到的内容({current_stage_name})
|
||||
{filled_slots_str}
|
||||
|
||||
@@ -296,7 +373,7 @@ def get_guided_conversation_prompt(
|
||||
|
||||
## 整体进度
|
||||
{progress_str}
|
||||
|
||||
{era_context}
|
||||
## 用户刚才说
|
||||
"{user_message}"
|
||||
|
||||
@@ -309,6 +386,7 @@ def get_guided_conversation_prompt(
|
||||
3. **保持自然**:不要每次都追问,有时候可以分享感受、表达好奇、或者轻松聊两句
|
||||
4. **适时引导**:跟着用户的节奏聊了几轮后,如果有自然的时机,可以温和地引向还没聊到的人生阶段,但绝不要生硬
|
||||
5. **追问要具体**:如果要追问,问具体的细节,比如"那时候是什么季节""身边有谁陪着你""当时心里什么感觉"
|
||||
6. **融入时代感**:如果有时代背景信息,在聊天中自然地提及当时的社会环境、流行文化、历史事件,让对话更有代入感和共鸣
|
||||
{dynamic_guidance}{uncovered_hint}
|
||||
|
||||
## 回复格式
|
||||
@@ -331,6 +409,8 @@ def get_guided_conversation_prompt(
|
||||
- "那个年代的xxx确实是这样"(理解)
|
||||
- "所以后来怎么样了?"(好奇)
|
||||
- "对了,你刚才提到xxx,那个时候..."(换话题)
|
||||
- "那会儿好像正赶上改革开放,你们那边变化大吗?"(时代融入)
|
||||
- "80年代初的xxx,你还有印象吗?"(时代细节)
|
||||
|
||||
直接输出你要说的话(多条消息用 [SPLIT] 分隔):"""
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
回忆录整理 Agent 提示词模板
|
||||
"""
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
# 章节分类映射
|
||||
CHAPTER_CATEGORIES = {
|
||||
@@ -177,42 +178,89 @@ def get_state_extraction_prompt(user_message: str, current_stage: str, stage_slo
|
||||
"""
|
||||
|
||||
|
||||
def get_creative_title_prompt(stage: str, emotion: str, slots: dict) -> str:
|
||||
"""生成有创意的章节标题"""
|
||||
def _build_age_hint(stage: str, birth_year: Optional[int] = None) -> str:
|
||||
"""根据人生阶段和出生年份推算大致年龄区间"""
|
||||
if not birth_year:
|
||||
return ""
|
||||
stage_age_ranges = {
|
||||
"childhood": (0, 12),
|
||||
"education": (6, 22),
|
||||
"career": (18, 60),
|
||||
"career_early": (18, 30),
|
||||
"career_achievement": (25, 55),
|
||||
"career_challenge": (20, 55),
|
||||
"family": (20, 60),
|
||||
"belief": (30, 70),
|
||||
"beliefs": (30, 70),
|
||||
"summary": (50, 80),
|
||||
}
|
||||
age_range = stage_age_ranges.get(stage)
|
||||
if not age_range:
|
||||
return ""
|
||||
year_start = birth_year + age_range[0]
|
||||
year_end = birth_year + age_range[1]
|
||||
return f"大约 {year_start}-{year_end} 年({age_range[0]}-{age_range[1]} 岁)"
|
||||
|
||||
|
||||
def get_creative_title_prompt(
|
||||
stage: str,
|
||||
emotion: str,
|
||||
slots: dict,
|
||||
user_profile: str = "",
|
||||
birth_year: Optional[int] = None,
|
||||
) -> str:
|
||||
"""生成有创意的章节标题,包含年龄/时间信息"""
|
||||
age_hint = _build_age_hint(stage, birth_year)
|
||||
profile_section = f"\n用户基本信息:\n{user_profile}" if user_profile else ""
|
||||
time_section = f"\n时间参考:{age_hint}" if age_hint else ""
|
||||
|
||||
return f"""{get_system_prompt()}
|
||||
|
||||
请根据阶段和情绪生成 1 个有创意的章节标题。
|
||||
阶段:{stage}
|
||||
情绪:{emotion}
|
||||
可用信息:{slots}
|
||||
可用信息:{slots}{profile_section}{time_section}
|
||||
|
||||
要求:
|
||||
1. 标题 12-18 字以内
|
||||
1. 标题格式:「时间标注 · 标题正文」
|
||||
- 时间标注用年龄或年代表示,如"6-12岁"、"1980年代"、"二十出头"
|
||||
- 标题正文 12-18 字以内
|
||||
2. 情绪 + 人生阶段 + 意象
|
||||
3. 示例风格:
|
||||
- 《那个夏天,我第一次离开家》
|
||||
- 《在陌生城市站稳脚跟》
|
||||
- 《不是所有选择都被理解》
|
||||
- 《慢下来,人生开始发声》
|
||||
- 《6-12岁 · 那条巷子尽头的蝉鸣》
|
||||
- 《18岁 · 第一次离开家的夏天》
|
||||
- 《25-35岁 · 在陌生城市站稳脚跟》
|
||||
- 《四十不惑 · 慢下来,人生开始发声》
|
||||
- 《1990年代 · 不是所有选择都被理解》
|
||||
|
||||
只输出标题文字,不要加引号或其他内容。
|
||||
只输出标题文字,不要加引号或书名号。
|
||||
"""
|
||||
|
||||
|
||||
def get_narrative_prompt(stage: str, slots: dict, new_content: str, existing_content: str = "") -> str:
|
||||
def get_narrative_prompt(
|
||||
stage: str,
|
||||
slots: dict,
|
||||
new_content: str,
|
||||
existing_content: str = "",
|
||||
user_profile: str = "",
|
||||
birth_year: Optional[int] = None,
|
||||
) -> str:
|
||||
"""将新对话改写为叙述(只输出新内容的改写,不重复已有内容)"""
|
||||
# 只取已有内容的末尾作为衔接上下文
|
||||
context_tail = ""
|
||||
if existing_content:
|
||||
context_tail = existing_content[-300:] if len(existing_content) > 300 else existing_content
|
||||
|
||||
context_section = f"\n\n【衔接上下文(已有内容的末尾,仅供参考衔接,不要重复)】:\n{context_tail}" if context_tail else ""
|
||||
|
||||
profile_section = f"\n\n用户基本信息:\n{user_profile}" if user_profile else ""
|
||||
age_hint = _build_age_hint(stage, birth_year)
|
||||
time_section = f"\n时间参考:{age_hint}" if age_hint else ""
|
||||
|
||||
return f"""{get_system_prompt()}
|
||||
|
||||
请将以下新的对话内容改写为第一人称文学叙述。
|
||||
阶段:{stage}
|
||||
可用信息:{slots}
|
||||
可用信息:{slots}{profile_section}{time_section}
|
||||
|
||||
新的对话内容:
|
||||
{new_content}
|
||||
@@ -225,6 +273,7 @@ def get_narrative_prompt(stage: str, slots: dict, new_content: str, existing_con
|
||||
4. 如果有衔接上下文,确保新内容与之自然衔接(语气、时间线连贯)
|
||||
5. 语气自然,有情绪
|
||||
6. 在适合配图的地方插入图片占位符
|
||||
7. 如果有用户的基本信息(出生地、成长地等),在叙述中自然融入地域文化和时代背景
|
||||
|
||||
## 图片占位符格式
|
||||
在描述场景、人物、重要时刻的段落后,插入图片占位符,格式为:
|
||||
|
||||
163
api/agents/prompts/profile_prompts.py
Normal file
163
api/agents/prompts/profile_prompts.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
用户基础资料收集提示词
|
||||
"""
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
PROFILE_FIELD_NAMES = {
|
||||
"birth_year": "出生年份",
|
||||
"birth_place": "出生地",
|
||||
"grew_up_place": "成长地",
|
||||
"occupation": "职业",
|
||||
}
|
||||
|
||||
|
||||
def get_profile_greeting_prompt(missing_fields: List[str], nickname: str = "") -> str:
|
||||
"""生成初次见面、收集基础资料的引导提示词"""
|
||||
missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES]
|
||||
missing_str = "、".join(missing_names)
|
||||
name_part = f",{nickname}" if nickname else ""
|
||||
|
||||
return f"""你是「岁月知己」,一位温暖真诚的人生故事访谈者。你正在和用户初次见面{name_part}。
|
||||
|
||||
在正式聊人生故事之前,你需要先了解一些基本信息。还需要了解的信息有:{missing_str}。
|
||||
|
||||
## 你的任务
|
||||
用自然、亲切的方式,像老朋友聊天一样,向用户询问这些基础信息。
|
||||
|
||||
## 规则
|
||||
1. 不要一次问所有问题,每次只问 1-2 个
|
||||
2. 如果用户已经在对话中提到了某些信息,不要重复问
|
||||
3. 用口语化、亲切的方式提问
|
||||
4. 当所有信息都收集完后,自然过渡到人生故事访谈
|
||||
|
||||
## 提问示例
|
||||
- "你是哪一年出生的呀?"
|
||||
- "你是在哪里出生的?小时候也是在那里长大的吗?"
|
||||
- "你现在是做什么工作的呀?或者之前主要从事什么职业?"
|
||||
|
||||
## 严格禁止
|
||||
- 禁止输出括号注释、思考过程
|
||||
- 禁止说"我需要收集信息"之类的机械话
|
||||
- 禁止一次列出所有问题
|
||||
|
||||
## 回复格式
|
||||
- 如果内容较多,可以用 [SPLIT] 分隔成多条消息
|
||||
- 像微信聊天一样自然
|
||||
|
||||
直接输出你要说的话:"""
|
||||
|
||||
|
||||
def get_profile_extraction_prompt(user_message: str, missing_fields: List[str]) -> str:
|
||||
"""从用户回答中提取基础资料信息"""
|
||||
missing_names = {f: PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES}
|
||||
|
||||
return f"""请从用户的回答中提取基础资料信息。
|
||||
|
||||
用户的回答:
|
||||
"{user_message}"
|
||||
|
||||
需要提取的字段(只提取确实提到的):
|
||||
{missing_names}
|
||||
|
||||
请返回 JSON 格式,只包含确实提到的字段:
|
||||
{{
|
||||
"birth_year": 1965,
|
||||
"birth_place": "湖南长沙",
|
||||
"grew_up_place": "湖南长沙",
|
||||
"occupation": "教师"
|
||||
}}
|
||||
|
||||
规则:
|
||||
1. birth_year 必须是整数(四位数年份),如"65年出生"应转为 1965
|
||||
2. 如果用户说"在老家长大"而之前提到了出生地,grew_up_place 可以和 birth_place 相同
|
||||
3. 只提取明确提到的信息,不要猜测
|
||||
4. 如果没有提取到任何信息,返回空对象 {{}}
|
||||
|
||||
只返回 JSON,不要其他内容。"""
|
||||
|
||||
|
||||
def get_profile_followup_prompt(
|
||||
missing_fields: List[str],
|
||||
filled_fields: Dict[str, str],
|
||||
user_message: str,
|
||||
nickname: str = "",
|
||||
) -> str:
|
||||
"""在收集资料过程中的跟进提问"""
|
||||
missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES]
|
||||
missing_str = "、".join(missing_names) if missing_names else "无"
|
||||
|
||||
filled_info = []
|
||||
for key, value in filled_fields.items():
|
||||
name = PROFILE_FIELD_NAMES.get(key, key)
|
||||
filled_info.append(f"{name}: {value}")
|
||||
filled_str = "\n".join(filled_info) if filled_info else "暂无"
|
||||
|
||||
if not missing_names:
|
||||
return f"""你是「岁月知己」。用户的基本信息已经收集完毕:
|
||||
{filled_str}
|
||||
|
||||
用户刚才说:"{user_message}"
|
||||
|
||||
请对用户的回答做出温暖的回应,然后自然地过渡到人生故事的访谈。
|
||||
可以说类似"了解了!那我们现在开始聊聊你的人生故事吧"这样的话,然后问一个关于童年的问题作为开场。
|
||||
|
||||
回复格式:多条消息用 [SPLIT] 分隔。
|
||||
直接输出你要说的话:"""
|
||||
|
||||
return f"""你是「岁月知己」,正在和用户聊天收集基本信息。
|
||||
|
||||
已知信息:
|
||||
{filled_str}
|
||||
|
||||
还需要了解:{missing_str}
|
||||
|
||||
用户刚才说:"{user_message}"
|
||||
|
||||
请先对用户说的内容做出自然回应,然后继续询问还未了解的信息(每次问 1-2 个)。
|
||||
语气要像朋友聊天一样自然亲切。
|
||||
|
||||
严格禁止:
|
||||
- 禁止输出括号注释、思考过程
|
||||
- 禁止说"我注意到""我需要了解"
|
||||
|
||||
回复格式:多条消息用 [SPLIT] 分隔。
|
||||
直接输出你要说的话:"""
|
||||
|
||||
|
||||
def format_user_profile_context(
|
||||
birth_year: Optional[int] = None,
|
||||
birth_place: Optional[str] = None,
|
||||
grew_up_place: Optional[str] = None,
|
||||
occupation: Optional[str] = None,
|
||||
) -> str:
|
||||
"""将用户基础信息格式化为上下文字符串,供其他 agent 使用"""
|
||||
parts = []
|
||||
if birth_year:
|
||||
parts.append(f"出生年份:{birth_year}年")
|
||||
if birth_place:
|
||||
parts.append(f"出生地:{birth_place}")
|
||||
if grew_up_place:
|
||||
parts.append(f"成长地:{grew_up_place}")
|
||||
if occupation:
|
||||
parts.append(f"职业:{occupation}")
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def get_missing_profile_fields(
|
||||
birth_year: Optional[int] = None,
|
||||
birth_place: Optional[str] = None,
|
||||
grew_up_place: Optional[str] = None,
|
||||
occupation: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""返回缺失的用户资料字段列表"""
|
||||
missing = []
|
||||
if not birth_year:
|
||||
missing.append("birth_year")
|
||||
if not birth_place:
|
||||
missing.append("birth_place")
|
||||
if not grew_up_place:
|
||||
missing.append("grew_up_place")
|
||||
if not occupation:
|
||||
missing.append("occupation")
|
||||
return missing
|
||||
@@ -30,6 +30,11 @@ class User(Base):
|
||||
subscription_expires_at = Column(DateTime(timezone=True), nullable=True) # 订阅到期时间
|
||||
created_at = Column(DateTime(timezone=True), default=utc_now)
|
||||
|
||||
birth_year = Column(Integer, nullable=True)
|
||||
birth_place = Column(String, nullable=True)
|
||||
grew_up_place = Column(String, nullable=True)
|
||||
occupation = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
conversations = relationship("Conversation", back_populates="user")
|
||||
chapters = relationship("Chapter", back_populates="user")
|
||||
|
||||
5
api/migrations/add_user_profile_fields.sql
Normal file
5
api/migrations/add_user_profile_fields.sql
Normal file
@@ -0,0 +1,5 @@
|
||||
-- 添加用户基础资料字段(出生年份、出生地、成长地、职业)
|
||||
ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_year INTEGER;
|
||||
ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_place VARCHAR;
|
||||
ALTER TABLE users ADD COLUMN IF NOT EXISTS grew_up_place VARCHAR;
|
||||
ALTER TABLE users ADD COLUMN IF NOT EXISTS occupation VARCHAR;
|
||||
@@ -28,6 +28,18 @@ class UserProfileResponse(BaseModel):
|
||||
avatar_url: Optional[str]
|
||||
subscription_type: str
|
||||
created_at: str
|
||||
birth_year: Optional[int] = None
|
||||
birth_place: Optional[str] = None
|
||||
grew_up_place: Optional[str] = None
|
||||
occupation: Optional[str] = None
|
||||
|
||||
|
||||
class UpdateUserProfileRequest(BaseModel):
|
||||
"""更新用户基础资料请求"""
|
||||
birth_year: Optional[int] = None
|
||||
birth_place: Optional[str] = None
|
||||
grew_up_place: Optional[str] = None
|
||||
occupation: Optional[str] = None
|
||||
|
||||
|
||||
class TestSubscriptionRequest(BaseModel):
|
||||
@@ -59,7 +71,43 @@ async def get_user_profile(
|
||||
nickname=current_user.nickname,
|
||||
avatar_url=current_user.avatar_url,
|
||||
subscription_type=current_user.subscription_type,
|
||||
created_at=current_user.created_at.isoformat()
|
||||
created_at=current_user.created_at.isoformat(),
|
||||
birth_year=current_user.birth_year,
|
||||
birth_place=current_user.birth_place,
|
||||
grew_up_place=current_user.grew_up_place,
|
||||
occupation=current_user.occupation,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/profile", response_model=UserProfileResponse)
|
||||
async def update_user_profile(
|
||||
body: UpdateUserProfileRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
):
|
||||
"""更新用户基础资料(出生年份、出生地、成长地、职业)"""
|
||||
if body.birth_year is not None:
|
||||
current_user.birth_year = body.birth_year
|
||||
if body.birth_place is not None:
|
||||
current_user.birth_place = body.birth_place
|
||||
if body.grew_up_place is not None:
|
||||
current_user.grew_up_place = body.grew_up_place
|
||||
if body.occupation is not None:
|
||||
current_user.occupation = body.occupation
|
||||
await db.commit()
|
||||
await db.refresh(current_user)
|
||||
return UserProfileResponse(
|
||||
id=current_user.id,
|
||||
phone=current_user.phone,
|
||||
email=current_user.email,
|
||||
nickname=current_user.nickname,
|
||||
avatar_url=current_user.avatar_url,
|
||||
subscription_type=current_user.subscription_type,
|
||||
created_at=current_user.created_at.isoformat(),
|
||||
birth_year=current_user.birth_year,
|
||||
birth_place=current_user.birth_place,
|
||||
grew_up_place=current_user.grew_up_place,
|
||||
occupation=current_user.occupation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -161,6 +161,28 @@ async def websocket_endpoint(
|
||||
return
|
||||
|
||||
|
||||
# 首次连接时检查资料完整性,发送资料收集开场白
|
||||
missing_profile = _get_missing_profile_fields(user)
|
||||
if missing_profile:
|
||||
try:
|
||||
greetings = await manager.conversation_agent.generate_profile_greeting(
|
||||
conversation_id=conversation_id,
|
||||
missing_fields=missing_profile,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
import asyncio as _asyncio_greet
|
||||
for i, text in enumerate(greetings):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": text, "index": i, "total": len(greetings)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
if i < len(greetings) - 1:
|
||||
await _asyncio_greet.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
|
||||
|
||||
# 主循环:处理消息
|
||||
while True:
|
||||
try:
|
||||
@@ -206,7 +228,8 @@ async def websocket_endpoint(
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
manager=manager
|
||||
manager=manager,
|
||||
user=user,
|
||||
)
|
||||
|
||||
elif msg_type == MessageType.AUDIO_MESSAGE:
|
||||
@@ -267,7 +290,8 @@ async def websocket_endpoint(
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
manager=manager
|
||||
manager=manager,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
# 转写失败,发送错误消息
|
||||
@@ -380,58 +404,140 @@ async def websocket_endpoint(
|
||||
await manager.disconnect(conversation_id)
|
||||
|
||||
|
||||
def _get_missing_profile_fields(user: UserModel) -> list:
|
||||
"""检查用户缺失的资料字段"""
|
||||
from agents.prompts.profile_prompts import get_missing_profile_fields
|
||||
return get_missing_profile_fields(
|
||||
birth_year=user.birth_year,
|
||||
birth_place=user.birth_place,
|
||||
grew_up_place=user.grew_up_place,
|
||||
occupation=user.occupation,
|
||||
)
|
||||
|
||||
|
||||
def _get_filled_profile_fields(user: UserModel) -> dict:
|
||||
"""获取用户已有的资料字段(中文展示)"""
|
||||
from agents.prompts.profile_prompts import PROFILE_FIELD_NAMES
|
||||
filled = {}
|
||||
if user.birth_year:
|
||||
filled["birth_year"] = str(user.birth_year)
|
||||
if user.birth_place:
|
||||
filled["birth_place"] = user.birth_place
|
||||
if user.grew_up_place:
|
||||
filled["grew_up_place"] = user.grew_up_place
|
||||
if user.occupation:
|
||||
filled["occupation"] = user.occupation
|
||||
return filled
|
||||
|
||||
|
||||
async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession):
|
||||
"""将提取到的资料信息保存到用户模型"""
|
||||
changed = False
|
||||
if "birth_year" in extracted and not user.birth_year:
|
||||
user.birth_year = extracted["birth_year"]
|
||||
changed = True
|
||||
if "birth_place" in extracted and not user.birth_place:
|
||||
user.birth_place = extracted["birth_place"]
|
||||
changed = True
|
||||
if "grew_up_place" in extracted and not user.grew_up_place:
|
||||
user.grew_up_place = extracted["grew_up_place"]
|
||||
changed = True
|
||||
if "occupation" in extracted and not user.occupation:
|
||||
user.occupation = extracted["occupation"]
|
||||
changed = True
|
||||
if changed:
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
async def process_user_message(
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
conversation: Conversation,
|
||||
segment: Segment,
|
||||
db: AsyncSession,
|
||||
manager: ConnectionManager
|
||||
manager: ConnectionManager,
|
||||
user: UserModel = None,
|
||||
) -> None:
|
||||
"""
|
||||
处理用户消息,生成Agent回应(异步版本)
|
||||
|
||||
Args:
|
||||
conversation_id: 对话ID
|
||||
user_message: 用户消息文本
|
||||
conversation: 对话对象
|
||||
segment: 段落对象
|
||||
db: 数据库会话
|
||||
manager: 连接管理器
|
||||
|
||||
Returns:
|
||||
更新后的对话阶段
|
||||
支持资料收集模式和正式访谈模式
|
||||
"""
|
||||
import asyncio as _asyncio
|
||||
|
||||
|
||||
agent = manager.conversation_agent
|
||||
|
||||
# --- 资料收集模式 ---
|
||||
if user:
|
||||
missing = _get_missing_profile_fields(user)
|
||||
if missing:
|
||||
try:
|
||||
extracted = await agent.extract_profile_from_message(user_message, missing)
|
||||
if extracted:
|
||||
await _apply_extracted_profile(user, extracted, db)
|
||||
|
||||
remaining = _get_missing_profile_fields(user)
|
||||
filled = _get_filled_profile_fields(user)
|
||||
responses = await agent.generate_profile_followup(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
missing_fields=remaining,
|
||||
filled_fields=filled,
|
||||
nickname=user.nickname or "",
|
||||
)
|
||||
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
await db.commit()
|
||||
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
if i < len(responses) - 1:
|
||||
await _asyncio.sleep(0.5)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"资料收集处理失败: {e}", exc_info=True)
|
||||
|
||||
# --- 正式访谈模式 ---
|
||||
state = await get_or_create_state(conversation.user_id, db)
|
||||
|
||||
if conversation.conversation_stage != state.current_stage:
|
||||
conversation.conversation_stage = state.current_stage
|
||||
await db.commit()
|
||||
|
||||
# 获取已聊话题(保留老逻辑用于提示)
|
||||
stmt_segments = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id
|
||||
).order_by(Segment.created_at)
|
||||
result_segments = await db.execute(stmt_segments)
|
||||
previous_segments = result_segments.scalars().all()
|
||||
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
|
||||
|
||||
|
||||
# 构建用户资料上下文
|
||||
user_profile_context = ""
|
||||
if user:
|
||||
from agents.prompts.profile_prompts import format_user_profile_context
|
||||
user_profile_context = format_user_profile_context(
|
||||
birth_year=user.birth_year,
|
||||
birth_place=user.birth_place,
|
||||
grew_up_place=user.grew_up_place,
|
||||
occupation=user.occupation,
|
||||
)
|
||||
|
||||
try:
|
||||
# 异步生成回应(可能是多条消息)
|
||||
responses = await agent.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
memoir_state=state
|
||||
memoir_state=state,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
|
||||
# 更新段落的 Agent 回应(存储完整内容)
|
||||
|
||||
segment.agent_response = "\n\n".join(responses)
|
||||
await db.commit()
|
||||
|
||||
# 发送 Agent 回应(支持多条消息)
|
||||
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
@@ -439,13 +545,11 @@ async def process_user_message(
|
||||
"data": {"text": response_text, "index": i, "total": len(responses)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
# 多条消息之间稍作间隔,模拟打字效果
|
||||
if i < len(responses) - 1:
|
||||
await _asyncio.sleep(0.5)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理用户消息失败: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
@@ -455,8 +559,6 @@ async def process_user_message(
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
|
||||
@@ -14,7 +14,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database.database import SessionLocal
|
||||
from database.models import Book, Chapter, Segment, MemoirState
|
||||
from database.models import Book, Chapter, Segment, MemoirState, User
|
||||
from services.llm_service import llm_service
|
||||
from agents.state_schema import MemoirStateSchema, SlotData, default_state
|
||||
from agents.prompts.memory_prompts import (
|
||||
@@ -23,6 +23,7 @@ from agents.prompts.memory_prompts import (
|
||||
get_state_extraction_prompt,
|
||||
STAGE_TO_ORDER,
|
||||
)
|
||||
from agents.prompts.profile_prompts import format_user_profile_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -179,9 +180,21 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
logger.warning(f"未找到段落: {segment_ids}")
|
||||
return {"status": "no_segments"}
|
||||
|
||||
# 获取用户状态
|
||||
# 获取用户状态和资料
|
||||
state = _get_or_create_state_sync(user_id, db)
|
||||
llm = llm_service.get_llm()
|
||||
|
||||
user_obj = db.get(User, user_id)
|
||||
user_profile = ""
|
||||
user_birth_year = None
|
||||
if user_obj:
|
||||
user_birth_year = user_obj.birth_year
|
||||
user_profile = format_user_profile_context(
|
||||
birth_year=user_obj.birth_year,
|
||||
birth_place=user_obj.birth_place,
|
||||
grew_up_place=user_obj.grew_up_place,
|
||||
occupation=user_obj.occupation,
|
||||
)
|
||||
|
||||
# 按阶段分组处理
|
||||
stage_to_segments: Dict[str, List[Segment]] = {}
|
||||
@@ -257,7 +270,9 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
title_prompt = get_creative_title_prompt(
|
||||
stage=stage,
|
||||
emotion="neutral",
|
||||
slots=slot_snippets
|
||||
slots=slot_snippets,
|
||||
user_profile=user_profile,
|
||||
birth_year=user_birth_year,
|
||||
)
|
||||
title_response = llm.invoke(title_prompt)
|
||||
title = title_response.content.strip().strip('"')
|
||||
@@ -267,6 +282,8 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
slots=slot_snippets,
|
||||
new_content=combined_text,
|
||||
existing_content=existing_content,
|
||||
user_profile=user_profile,
|
||||
birth_year=user_birth_year,
|
||||
)
|
||||
narrative_response = llm.invoke(narrative_prompt)
|
||||
new_narrative = narrative_response.content.strip()
|
||||
|
||||
Reference in New Issue
Block a user