Merge branch 'feat/improve-agent-prompt'

This commit is contained in:
penghanyuan
2026-03-01 10:12:23 +01:00
parent a69d5c625c
commit c1e2fb31a0
11 changed files with 644 additions and 65 deletions

View File

@@ -7,7 +7,13 @@ on:
paths:
- 'api/**'
- '.github/workflows/**'
workflow_dispatch: # 允许手动触发
workflow_dispatch:
inputs:
branch:
description: '部署分支'
required: false
type: string
default: ''
env:
IMAGE_NAME: lifecho-api
@@ -24,6 +30,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.branch || github.ref }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -77,6 +85,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.event.inputs.branch || github.ref }}
- name: Set up SSH
uses: webfactory/ssh-agent@v0.9.0
@@ -91,11 +101,12 @@ jobs:
- name: Determine image tag
id: image_tag
run: |
if [ "${{ github.ref_name }}" == "main" ] || [ "${{ github.ref_name }}" == "master" ]; then
DEPLOY_BRANCH="${{ github.event.inputs.branch || github.ref_name }}"
echo "deploy_branch=$DEPLOY_BRANCH" >> $GITHUB_OUTPUT
if [ "$DEPLOY_BRANCH" == "main" ] || [ "$DEPLOY_BRANCH" == "master" ]; then
echo "tag=latest" >> $GITHUB_OUTPUT
else
# 将分支名中的斜杠替换为破折号,以符合 Docker 标签规范
BRANCH_TAG=$(echo "${{ github.ref_name }}" | sed 's/\//-/g')
BRANCH_TAG=$(echo "$DEPLOY_BRANCH" | sed 's/\//-/g')
echo "tag=$BRANCH_TAG" >> $GITHUB_OUTPUT
fi
@@ -214,6 +225,11 @@ jobs:
"docker exec -i life-echo-postgres psql -U $DB_USER -d $DB_NAME" \
< api/migrations/add_chapter_is_active.sql
echo "添加用户基础资料字段..."
ssh -p $SSH_PORT $SSH_USER@$SSH_HOST \
"docker exec -i life-echo-postgres psql -U $DB_USER -d $DB_NAME" \
< api/migrations/add_user_profile_fields.sql
echo "数据库迁移完成"
- name: Verify deployment

View File

@@ -1,7 +1,9 @@
"""
对话 Agent基于访谈问题清单动态选择问题实时生成回应
支持异步调用和 Redis 会话存储
支持用户基础资料收集和时代背景融入
"""
import json
import logging
from typing import List, Optional, Dict, Any
@@ -11,6 +13,13 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from services.llm_service import llm_service
from services.redis_service import redis_service
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
from .prompts.profile_prompts import (
get_profile_greeting_prompt,
get_profile_extraction_prompt,
get_profile_followup_prompt,
format_user_profile_context,
get_missing_profile_fields,
)
from .state_schema import MemoirStateSchema
logger = logging.getLogger(__name__)
@@ -102,6 +111,87 @@ class ConversationAgent:
logger.error(f"生成回应失败: {e}")
return f"抱歉,生成回应时出现错误: {str(e)}"
async def generate_profile_greeting(
    self,
    conversation_id: str,
    missing_fields: List[str],
    nickname: str = "",
) -> List[str]:
    """Produce the opening messages that start profile collection.

    Intended for the first turn of a conversation while some profile fields
    are still missing.

    Args:
        conversation_id: Conversation whose history is read and appended to.
        missing_fields: Profile field keys that still need to be collected.
        nickname: Optional user nickname passed on to the prompt builder.

    Returns:
        Up to three message strings. A canned greeting is returned when no
        LLM is configured or when anything in the generation path raises.
    """
    if not self.llm:
        # No model available: answer with a fixed opener.
        return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]

    try:
        greeting_prompt = get_profile_greeting_prompt(missing_fields, nickname)
        past_messages = await self._get_history_messages(conversation_id)
        transcript = self._format_history_string(past_messages)

        # Only append the transcript when there is one.
        llm_input = greeting_prompt
        if transcript:
            llm_input = f"{greeting_prompt}\n\n{transcript}"

        reply = await self.llm.ainvoke(llm_input)
        if hasattr(reply, 'content'):
            reply_text = reply.content
        else:
            reply_text = str(reply)

        # The raw reply (including any [SPLIT] markers) is what gets stored.
        await self._save_message(conversation_id, "ai", reply_text)

        bubbles = []
        for piece in reply_text.split("[SPLIT]"):
            piece = piece.strip()
            if piece:
                bubbles.append(piece)

        if not bubbles:
            return [reply_text]
        return bubbles[:3]
    except Exception as e:
        logger.error(f"生成资料收集开场白失败: {e}")
        return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
def _parse_profile_payload(self, content: str) -> Dict[str, Any]:
    """Turn the model's raw reply into a validated profile dict.

    The reply is expected to contain one JSON object, but models often wrap
    it in a Markdown code fence or add a short preamble, so only the span
    from the first "{" to the last "}" is decoded.

    Args:
        content: Raw text returned by the model.

    Returns:
        A dict holding any of ``birth_year`` (int), ``birth_place``,
        ``grew_up_place`` and ``occupation`` (non-empty str). Empty when the
        reply holds no JSON object or the payload is not an object.

    Raises:
        json.JSONDecodeError: When a brace-delimited span is present but is
            not valid JSON.
    """
    text = content.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return {}

    parsed = json.loads(text[start:end + 1])
    if not isinstance(parsed, dict):
        return {}

    result: Dict[str, Any] = {}

    year = parsed.get("birth_year")
    if isinstance(year, str) and year.strip().isdecimal():
        # Tolerate a year that was emitted as a JSON string.
        year = int(year.strip())
    # bool is a subclass of int, so exclude it explicitly; the extraction
    # prompt requires a four-digit year.
    if isinstance(year, int) and not isinstance(year, bool) and 1000 <= year <= 9999:
        result["birth_year"] = year

    for field_name in ("birth_place", "grew_up_place", "occupation"):
        value = parsed.get(field_name)
        if value:
            cleaned = str(value).strip()
            if cleaned:
                result[field_name] = cleaned
    return result


async def extract_profile_from_message(self, user_message: str, missing_fields: List[str]) -> Dict[str, Any]:
    """Extract basic profile information from a user message via the LLM.

    Args:
        user_message: The text the user just sent.
        missing_fields: Profile field keys that are still missing.

    Returns:
        The validated fields found in the message. Empty when no LLM is
        configured, when nothing is missing, or when the call or the
        parsing fails (the failure is logged).
    """
    if not self.llm or not missing_fields:
        return {}
    try:
        prompt = get_profile_extraction_prompt(user_message, missing_fields)
        response = await self.llm.ainvoke(prompt)
        # Same content access as the other generation methods in this class.
        content = response.content if hasattr(response, 'content') else str(response)
        return self._parse_profile_payload(content)
    except Exception as e:
        logger.error(f"提取资料信息失败: {e}")
        return {}
async def generate_profile_followup(
    self,
    conversation_id: str,
    user_message: str,
    missing_fields: List[str],
    filled_fields: Dict[str, str],
    nickname: str = "",
) -> List[str]:
    """Generate the follow-up reply while profile collection is in progress.

    Args:
        conversation_id: Conversation whose history is read and appended to.
        user_message: The text the user just sent.
        missing_fields: Profile field keys still to be collected.
        filled_fields: Profile values already known, keyed by field name.
        nickname: Optional user nickname passed on to the prompt builder.

    Returns:
        Up to three message strings. A canned reply is returned when no LLM
        is configured or when anything in the generation path raises.
    """
    if not self.llm:
        return ["谢谢!还能告诉我更多吗?"]

    try:
        followup_prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname)
        past_messages = await self._get_history_messages(conversation_id)
        transcript = self._format_history_string(past_messages)
        llm_input = f"{followup_prompt}\n\n{transcript}\n\nHuman: {user_message}\n\nAssistant:"

        reply = await self.llm.ainvoke(llm_input)
        if hasattr(reply, 'content'):
            reply_text = reply.content
        else:
            reply_text = str(reply)

        # Both sides of the turn are persisted only after the model answered.
        await self._save_message(conversation_id, "human", user_message)
        await self._save_message(conversation_id, "ai", reply_text)

        bubbles = []
        for piece in reply_text.split("[SPLIT]"):
            piece = piece.strip()
            if piece:
                bubbles.append(piece)

        if not bubbles:
            return [reply_text]
        return bubbles[:3]
    except Exception as e:
        logger.error(f"生成资料跟进回复失败: {e}")
        return ["谢谢分享!能再告诉我一些吗?"]
def _detect_user_stage(self, user_message: str) -> str:
"""
通过关键词检测用户当前正在谈论的人生阶段。
@@ -126,7 +216,8 @@ class ConversationAgent:
self,
conversation_id: str,
user_message: str,
memoir_state: MemoirStateSchema
memoir_state: MemoirStateSchema,
user_profile_context: str = "",
) -> List[str]:
"""
基于共享状态异步生成引导式回复
@@ -135,6 +226,7 @@ class ConversationAgent:
conversation_id: 对话 ID
user_message: 用户消息
memoir_state: 共享状态
user_profile_context: 用户基础资料上下文
Returns:
Agent 回应文本列表(支持多条消息)
@@ -150,18 +242,11 @@ class ConversationAgent:
if value.snippet
}
# 检测用户当前正在谈论的阶段
detected_user_stage = self._detect_user_stage(user_message)
# 从 Redis 获取对话历史,用于计算对话轮数
history_messages = await self._get_history_messages(conversation_id)
conversation_turn = len(history_messages) // 2 # 每轮包括一个用户消息和一个AI回复
# 计算同一话题的轮数(简单估算:基于已填充槽位的变化)
# 如果槽位数量没有增加,说明还在同一话题深入
conversation_turn = len(history_messages) // 2
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
# 获取所有阶段的覆盖情况
all_stages_coverage = memoir_state.all_stages_coverage()
system_prompt = get_guided_conversation_prompt(
@@ -173,24 +258,19 @@ class ConversationAgent:
same_topic_turns=same_topic_turns,
all_stages_coverage=all_stages_coverage,
detected_user_stage=detected_user_stage,
user_profile_context=user_profile_context,
)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
# 支持多条消息,用 [SPLIT] 分隔
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
# 最多返回 3 条
return messages[:3] if messages else [response_text]
except Exception as e:

View File

@@ -19,6 +19,14 @@ from .memory_prompts import (
CHAPTER_ORDER,
STAGE_TO_ORDER,
)
from .profile_prompts import (
get_profile_greeting_prompt,
get_profile_extraction_prompt,
get_profile_followup_prompt,
format_user_profile_context,
get_missing_profile_fields,
PROFILE_FIELD_NAMES,
)
__all__ = [
"ConversationStage",
@@ -35,5 +43,11 @@ __all__ = [
"CHAPTER_CATEGORIES",
"CHAPTER_ORDER",
"STAGE_TO_ORDER",
"get_profile_greeting_prompt",
"get_profile_extraction_prompt",
"get_profile_followup_prompt",
"format_user_profile_context",
"get_missing_profile_fields",
"PROFILE_FIELD_NAMES",
]

View File

@@ -173,6 +173,73 @@ RESPONSE_STYLES = [
]
def _profile_line_value(line: str) -> str:
    """Return the text after the first label separator of a profile line.

    Both the full-width and the ASCII colon are accepted; an empty string is
    returned when the line has neither.
    """
    for separator in (":", ":"):
        if separator in line:
            return line.split(separator, 1)[1].strip()
    return ""


def _build_era_context(current_stage: str, user_profile_context: str) -> str:
    """Build a prompt section describing the era a life stage fell in.

    The birth year (and, when present, a place) is read back out of the
    newline-separated "label:value" profile text. The life stage is mapped
    to an age range, which is turned into a calendar-year range, and every
    decade overlapping that range contributes its list of events.

    Args:
        current_stage: Life-stage key (e.g. "childhood", "career"). Unknown
            keys fall back to the 0-30 age range.
        user_profile_context: Profile text, one "label:value" pair per line.

    Returns:
        The prompt section, or an empty string when the profile is empty,
        no birth year can be parsed, or no listed decade overlaps the range.
    """
    if not user_profile_context:
        return ""

    birth_year = None
    birth_place = ""
    for line in user_profile_context.split("\n"):
        value = _profile_line_value(line)
        if "出生年份" in line:
            try:
                # Tolerate a trailing "年" after the digits.
                birth_year = int(value.replace("年", "").strip())
            except ValueError:
                pass
        if ("出生地" in line or "成长地" in line) and value:
            # A later line overrides an earlier one.
            birth_place = value

    if not birth_year:
        return ""

    stage_era_map = {
        "childhood": (0, 12),
        "education": (6, 22),
        "career": (18, 50),
        "family": (20, 50),
        "belief": (30, 60),
    }
    age_range = stage_era_map.get(current_stage, (0, 30))
    era_start = birth_year + age_range[0]
    era_end = birth_year + age_range[1]

    decade_events = {
        1950: "新中国成立初期、土地改革、抗美援朝",
        1960: "大跃进、三年自然灾害、中苏关系变化",
        1970: "文化大革命、知青上山下乡、中美建交",
        1980: "改革开放、恢复高考、个体经济兴起、电视普及",
        1990: "社会主义市场经济、下海潮、香港回归、互联网初期",
        2000: "加入WTO、房地产兴起、手机普及、北京奥运",
        2010: "移动互联网爆发、微信时代、共享经济、双创浪潮",
        2020: "新冠疫情、直播经济、人工智能崛起",
    }
    # A decade counts when [decade, decade + 9] overlaps [era_start, era_end].
    era_events = [
        f"{decade}年代:{events}"
        for decade, events in decade_events.items()
        if era_start <= decade + 9 and era_end >= decade
    ]
    if not era_events:
        return ""

    # One decade per line so the entries do not run into each other.
    era_lines = "\n".join(era_events)
    place_hint = f"(用户来自{birth_place})" if birth_place else ""
    return f"""
## 时代背景参考{place_hint}
用户在这个人生阶段大约经历了 {era_start}-{era_end} 年({age_range[0]}-{age_range[1]} 岁):
{era_lines}
你可以在对话中自然地提及这些时代元素来丰富提问,例如:
- "那个年代好像正好赶上xxx你们那边是什么情况"
- "听说那时候xxx特别流行你有印象吗"
- 不要生硬地列举历史事件,而是像聊天一样自然带入"""
def get_guided_conversation_prompt(
current_stage: str,
empty_slots: List[str],
@@ -182,6 +249,7 @@ def get_guided_conversation_prompt(
same_topic_turns: int = 0,
all_stages_coverage: Dict[str, Dict] = None,
detected_user_stage: str = "",
user_profile_context: str = "",
) -> str:
"""
生成状态感知的对话提示词
@@ -195,6 +263,7 @@ def get_guided_conversation_prompt(
same_topic_turns: 同一话题的轮数
all_stages_coverage: 所有阶段的覆盖情况 {stage: {total, filled, empty, ratio}}
detected_user_stage: 检测到用户正在谈论的阶段(可能和 current_stage 不同)
user_profile_context: 用户基础资料上下文
"""
stage_name_map = {
"childhood": "童年时光",
@@ -286,8 +355,16 @@ def get_guided_conversation_prompt(
else:
topic_desc = f"你们聊到了「{current_stage_name}」这个话题"
prompt = f"""你是「岁月知己」,用户的老朋友,正在和他/她聊人生故事。{topic_desc}
# --- 用户资料和时代背景 ---
profile_section = ""
if user_profile_context:
profile_section = f"\n## 用户基本信息\n{user_profile_context}\n"
active_stage = detected_user_stage if user_jumped and detected_user_stage else current_stage
era_context = _build_era_context(active_stage, user_profile_context)
prompt = f"""你是「岁月知己」,用户的老朋友,正在和他/她聊人生故事。{topic_desc}
{profile_section}
## 已经聊到的内容({current_stage_name}
{filled_slots_str}
@@ -296,7 +373,7 @@ def get_guided_conversation_prompt(
## 整体进度
{progress_str}
{era_context}
## 用户刚才说
"{user_message}"
@@ -309,6 +386,7 @@ def get_guided_conversation_prompt(
3. **保持自然**:不要每次都追问,有时候可以分享感受、表达好奇、或者轻松聊两句
4. **适时引导**:跟着用户的节奏聊了几轮后,如果有自然的时机,可以温和地引向还没聊到的人生阶段,但绝不要生硬
5. **追问要具体**:如果要追问,问具体的细节,比如"那时候是什么季节""身边有谁陪着你""当时心里什么感觉"
6. **融入时代感**:如果有时代背景信息,在聊天中自然地提及当时的社会环境、流行文化、历史事件,让对话更有代入感和共鸣
{dynamic_guidance}{uncovered_hint}
## 回复格式
@@ -331,6 +409,8 @@ def get_guided_conversation_prompt(
- "那个年代的xxx确实是这样"(理解)
- "所以后来怎么样了?"(好奇)
- "对了你刚才提到xxx那个时候..."(换话题)
- "那会儿好像正赶上改革开放,你们那边变化大吗?"(时代融入)
- "80年代初的xxx你还有印象吗"(时代细节)
直接输出你要说的话(多条消息用 [SPLIT] 分隔):"""

View File

@@ -2,6 +2,7 @@
回忆录整理 Agent 提示词模板
"""
import json
from typing import Optional
# 章节分类映射
CHAPTER_CATEGORIES = {
@@ -177,42 +178,89 @@ def get_state_extraction_prompt(user_message: str, current_stage: str, stage_slo
"""
def get_creative_title_prompt(stage: str, emotion: str, slots: dict) -> str:
"""生成有创意的章节标题"""
def _build_age_hint(stage: str, birth_year: Optional[int] = None) -> str:
    """Describe the approximate years and ages a life stage covers.

    Args:
        stage: Life-stage key looked up in the age-range table below.
        birth_year: The user's year of birth; a falsy value disables the hint.

    Returns:
        A short phrase giving the calendar-year range and the age range, or
        an empty string when the birth year is missing or the stage unknown.
    """
    age_bounds_by_stage = {
        "childhood": (0, 12),
        "education": (6, 22),
        "career": (18, 60),
        "career_early": (18, 30),
        "career_achievement": (25, 55),
        "career_challenge": (20, 55),
        "family": (20, 60),
        "belief": (30, 70),
        "beliefs": (30, 70),
        "summary": (50, 80),
    }

    bounds = age_bounds_by_stage.get(stage) if birth_year else None
    if not bounds:
        return ""

    youngest, oldest = bounds
    first_year = birth_year + youngest
    last_year = birth_year + oldest
    return f"大约 {first_year}-{last_year} 年({youngest}-{oldest} 岁)"
def get_creative_title_prompt(
stage: str,
emotion: str,
slots: dict,
user_profile: str = "",
birth_year: Optional[int] = None,
) -> str:
"""生成有创意的章节标题,包含年龄/时间信息"""
age_hint = _build_age_hint(stage, birth_year)
profile_section = f"\n用户基本信息:\n{user_profile}" if user_profile else ""
time_section = f"\n时间参考:{age_hint}" if age_hint else ""
return f"""{get_system_prompt()}
请根据阶段和情绪生成 1 个有创意的章节标题。
阶段:{stage}
情绪:{emotion}
可用信息:{slots}
可用信息:{slots}{profile_section}{time_section}
要求:
1. 标题 12-18 字以内
1. 标题格式:「时间标注 · 标题正文」
- 时间标注用年龄或年代表示,如"6-12岁""1980年代""二十出头"
- 标题正文 12-18 字以内
2. 情绪 + 人生阶段 + 意象
3. 示例风格:
- 《那个夏天,我第一次离开家
- 《在陌生城市站稳脚跟
- 《不是所有选择都被理解
- 《慢下来,人生开始发声》
- 《6-12岁 · 那条巷子尽头的蝉鸣
- 《18岁 · 第一次离开家的夏天
- 《25-35岁 · 在陌生城市站稳脚跟
- 《四十不惑 · 慢下来,人生开始发声》
- 《1990年代 · 不是所有选择都被理解》
只输出标题文字,不要加引号或其他内容
只输出标题文字,不要加引号或书名号
"""
def get_narrative_prompt(stage: str, slots: dict, new_content: str, existing_content: str = "") -> str:
def get_narrative_prompt(
stage: str,
slots: dict,
new_content: str,
existing_content: str = "",
user_profile: str = "",
birth_year: Optional[int] = None,
) -> str:
"""将新对话改写为叙述(只输出新内容的改写,不重复已有内容)"""
# 只取已有内容的末尾作为衔接上下文
context_tail = ""
if existing_content:
context_tail = existing_content[-300:] if len(existing_content) > 300 else existing_content
context_section = f"\n\n【衔接上下文(已有内容的末尾,仅供参考衔接,不要重复)】:\n{context_tail}" if context_tail else ""
profile_section = f"\n\n用户基本信息:\n{user_profile}" if user_profile else ""
age_hint = _build_age_hint(stage, birth_year)
time_section = f"\n时间参考:{age_hint}" if age_hint else ""
return f"""{get_system_prompt()}
请将以下新的对话内容改写为第一人称文学叙述。
阶段:{stage}
可用信息:{slots}
可用信息:{slots}{profile_section}{time_section}
新的对话内容:
{new_content}
@@ -225,6 +273,7 @@ def get_narrative_prompt(stage: str, slots: dict, new_content: str, existing_con
4. 如果有衔接上下文,确保新内容与之自然衔接(语气、时间线连贯)
5. 语气自然,有情绪
6. 在适合配图的地方插入图片占位符
7. 如果有用户的基本信息(出生地、成长地等),在叙述中自然融入地域文化和时代背景
## 图片占位符格式
在描述场景、人物、重要时刻的段落后,插入图片占位符,格式为:

View File

@@ -0,0 +1,163 @@
"""
用户基础资料收集提示词
"""
from typing import Dict, List, Optional
# Profile field key -> display name used inside the prompts.
PROFILE_FIELD_NAMES = {
    "birth_year": "出生年份",
    "birth_place": "出生地",
    "grew_up_place": "成长地",
    "occupation": "职业",
}


def get_profile_greeting_prompt(missing_fields: List[str], nickname: str = "") -> str:
    """Build the first-meeting prompt that asks for basic profile details.

    Args:
        missing_fields: Profile field keys still to be collected; keys not
            present in ``PROFILE_FIELD_NAMES`` are ignored.
        nickname: Optional user nickname mentioned in the prompt.

    Returns:
        The full prompt text for the LLM.
    """
    missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES]
    # Separate the names so the list stays readable.
    missing_str = "、".join(missing_names)
    # Mark the nickname explicitly instead of appending it bare to the sentence.
    name_part = f"(用户昵称:{nickname})" if nickname else ""
    return f"""你是「岁月知己」,一位温暖真诚的人生故事访谈者。你正在和用户初次见面{name_part}
在正式聊人生故事之前,你需要先了解一些基本信息。还需要了解的信息有:{missing_str}
## 你的任务
用自然、亲切的方式,像老朋友聊天一样,向用户询问这些基础信息。
## 规则
1. 不要一次问所有问题,每次只问 1-2 个
2. 如果用户已经在对话中提到了某些信息,不要重复问
3. 用口语化、亲切的方式提问
4. 当所有信息都收集完后,自然过渡到人生故事访谈
## 提问示例
- "你是哪一年出生的呀?"
- "你是在哪里出生的?小时候也是在那里长大的吗?"
- "你现在是做什么工作的呀?或者之前主要从事什么职业?"
## 严格禁止
- 禁止输出括号注释、思考过程
- 禁止说"我需要收集信息"之类的机械话
- 禁止一次列出所有问题
## 回复格式
- 如果内容较多,可以用 [SPLIT] 分隔成多条消息
- 像微信聊天一样自然
直接输出你要说的话:"""
def get_profile_extraction_prompt(user_message: str, missing_fields: List[str]) -> str:
    """Build the prompt asking the LLM to pull profile fields out of a reply.

    Args:
        user_message: The user's answer, quoted verbatim in the prompt.
        missing_fields: Profile field keys to look for; keys not present in
            ``PROFILE_FIELD_NAMES`` are dropped.

    Returns:
        The prompt text, which requests a JSON-only answer.
    """
    # Field key -> display name, restricted to the fields still wanted.
    wanted = {}
    for field_key in missing_fields:
        if field_key in PROFILE_FIELD_NAMES:
            wanted[field_key] = PROFILE_FIELD_NAMES[field_key]

    return f"""请从用户的回答中提取基础资料信息。
用户的回答:
"{user_message}"
需要提取的字段(只提取确实提到的):
{wanted}
请返回 JSON 格式,只包含确实提到的字段:
{{
"birth_year": 1965,
"birth_place": "湖南长沙",
"grew_up_place": "湖南长沙",
"occupation": "教师"
}}
规则:
1. birth_year 必须是整数(四位数年份),如"65年出生"应转为 1965
2. 如果用户说"在老家长大"而之前提到了出生地grew_up_place 可以和 birth_place 相同
3. 只提取明确提到的信息,不要猜测
4. 如果没有提取到任何信息,返回空对象 {{}}
只返回 JSON不要其他内容。"""
def get_profile_followup_prompt(
    missing_fields: List[str],
    filled_fields: Dict[str, str],
    user_message: str,
    nickname: str = "",
) -> str:
    """Build the follow-up prompt used while profile details are collected.

    Two prompts are possible: when nothing is missing any more the model is
    asked to acknowledge the answer and move on to the life-story interview;
    otherwise it is asked to respond and keep asking for the remaining fields.

    Args:
        missing_fields: Profile field keys still to be collected; keys not
            present in ``PROFILE_FIELD_NAMES`` are ignored.
        filled_fields: Known profile values keyed by field name.
        user_message: The text the user just sent.
        nickname: Accepted for signature parity with the other prompt
            builders; not used in the prompt text.

    Returns:
        The prompt text for the LLM.
    """
    missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES]
    # Separate the names; joining an empty list already yields "".
    missing_str = "、".join(missing_names)

    filled_info = [
        f"{PROFILE_FIELD_NAMES.get(key, key)}: {value}"
        for key, value in filled_fields.items()
    ]
    filled_str = "\n".join(filled_info) if filled_info else "暂无"

    if not missing_names:
        return f"""你是「岁月知己」。用户的基本信息已经收集完毕:
{filled_str}
用户刚才说:"{user_message}"
请对用户的回答做出温暖的回应,然后自然地过渡到人生故事的访谈。
可以说类似"了解了!那我们现在开始聊聊你的人生故事吧"这样的话,然后问一个关于童年的问题作为开场。
回复格式:多条消息用 [SPLIT] 分隔。
直接输出你要说的话:"""

    return f"""你是「岁月知己」,正在和用户聊天收集基本信息。
已知信息:
{filled_str}
还需要了解:{missing_str}
用户刚才说:"{user_message}"
请先对用户说的内容做出自然回应,然后继续询问还未了解的信息(每次问 1-2 个)。
语气要像朋友聊天一样自然亲切。
严格禁止:
- 禁止输出括号注释、思考过程
- 禁止说"我注意到""我需要了解"
回复格式:多条消息用 [SPLIT] 分隔。
直接输出你要说的话:"""
def format_user_profile_context(
    birth_year: Optional[int] = None,
    birth_place: Optional[str] = None,
    grew_up_place: Optional[str] = None,
    occupation: Optional[str] = None,
) -> str:
    """Render the known profile values as context text for other agents.

    Each truthy value becomes one "label:value" line, in the fixed order
    birth year, birth place, place grown up, occupation.

    Returns:
        The lines joined by newlines, or an empty string when every value
        is falsy.
    """
    labelled_values = (
        ("出生年份", birth_year),
        ("出生地", birth_place),
        ("成长地", grew_up_place),
        ("职业", occupation),
    )
    return "\n".join(
        f"{label}:{value}" for label, value in labelled_values if value
    )
def get_missing_profile_fields(
    birth_year: Optional[int] = None,
    birth_place: Optional[str] = None,
    grew_up_place: Optional[str] = None,
    occupation: Optional[str] = None,
) -> List[str]:
    """List the profile field keys whose value is missing.

    A field counts as missing when its value is falsy (None, 0 or "").

    Returns:
        The missing keys, in the order birth_year, birth_place,
        grew_up_place, occupation.
    """
    current_values = {
        "birth_year": birth_year,
        "birth_place": birth_place,
        "grew_up_place": grew_up_place,
        "occupation": occupation,
    }
    return [field_key for field_key, value in current_values.items() if not value]

View File

@@ -30,6 +30,11 @@ class User(Base):
subscription_expires_at = Column(DateTime(timezone=True), nullable=True) # 订阅到期时间
created_at = Column(DateTime(timezone=True), default=utc_now)
birth_year = Column(Integer, nullable=True)
birth_place = Column(String, nullable=True)
grew_up_place = Column(String, nullable=True)
occupation = Column(String, nullable=True)
# Relationships
conversations = relationship("Conversation", back_populates="user")
chapters = relationship("Chapter", back_populates="user")

View File

@@ -0,0 +1,5 @@
-- Add the basic user profile columns to the users table:
-- birth year, birth place, place grown up, and occupation.
-- Each statement uses IF NOT EXISTS, so a column that is already present is
-- left alone and the migration can be applied more than once.
ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_year INTEGER;
ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_place VARCHAR;
ALTER TABLE users ADD COLUMN IF NOT EXISTS grew_up_place VARCHAR;
ALTER TABLE users ADD COLUMN IF NOT EXISTS occupation VARCHAR;

View File

@@ -28,6 +28,18 @@ class UserProfileResponse(BaseModel):
avatar_url: Optional[str]
subscription_type: str
created_at: str
birth_year: Optional[int] = None
birth_place: Optional[str] = None
grew_up_place: Optional[str] = None
occupation: Optional[str] = None
class UpdateUserProfileRequest(BaseModel):
    """Request body for updating a user's basic profile.

    Every field is optional and defaults to None.
    """
    # Year of birth.
    birth_year: Optional[int] = None
    # Place of birth.
    birth_place: Optional[str] = None
    # Place where the user grew up.
    grew_up_place: Optional[str] = None
    # Occupation.
    occupation: Optional[str] = None
class TestSubscriptionRequest(BaseModel):
@@ -59,7 +71,43 @@ async def get_user_profile(
nickname=current_user.nickname,
avatar_url=current_user.avatar_url,
subscription_type=current_user.subscription_type,
created_at=current_user.created_at.isoformat()
created_at=current_user.created_at.isoformat(),
birth_year=current_user.birth_year,
birth_place=current_user.birth_place,
grew_up_place=current_user.grew_up_place,
occupation=current_user.occupation,
)
@router.put("/profile", response_model=UserProfileResponse)
async def update_user_profile(
    body: UpdateUserProfileRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_async_db),
):
    """Update the current user's basic profile.

    Only fields supplied with a non-None value (birth year, birth place,
    place grown up, occupation) are written; the others keep their stored
    value. The change is committed and the refreshed profile is returned.
    """
    for field_name in ("birth_year", "birth_place", "grew_up_place", "occupation"):
        submitted = getattr(body, field_name)
        if submitted is not None:
            setattr(current_user, field_name, submitted)

    await db.commit()
    await db.refresh(current_user)

    profile = dict(
        id=current_user.id,
        phone=current_user.phone,
        email=current_user.email,
        nickname=current_user.nickname,
        avatar_url=current_user.avatar_url,
        subscription_type=current_user.subscription_type,
        created_at=current_user.created_at.isoformat(),
        birth_year=current_user.birth_year,
        birth_place=current_user.birth_place,
        grew_up_place=current_user.grew_up_place,
        occupation=current_user.occupation,
    )
    return UserProfileResponse(**profile)

View File

@@ -161,6 +161,28 @@ async def websocket_endpoint(
return
# 首次连接时检查资料完整性,发送资料收集开场白
missing_profile = _get_missing_profile_fields(user)
if missing_profile:
try:
greetings = await manager.conversation_agent.generate_profile_greeting(
conversation_id=conversation_id,
missing_fields=missing_profile,
nickname=user.nickname or "",
)
import asyncio as _asyncio_greet
for i, text in enumerate(greetings):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": text, "index": i, "total": len(greetings)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(greetings) - 1:
await _asyncio_greet.sleep(0.5)
except Exception as e:
logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
# 主循环:处理消息
while True:
try:
@@ -206,7 +228,8 @@ async def websocket_endpoint(
conversation=conversation,
segment=segment,
db=db,
manager=manager
manager=manager,
user=user,
)
elif msg_type == MessageType.AUDIO_MESSAGE:
@@ -267,7 +290,8 @@ async def websocket_endpoint(
conversation=conversation,
segment=segment,
db=db,
manager=manager
manager=manager,
user=user,
)
else:
# 转写失败,发送错误消息
@@ -380,58 +404,140 @@ async def websocket_endpoint(
await manager.disconnect(conversation_id)
def _get_missing_profile_fields(user: UserModel) -> list:
    """Return the profile field keys that are still unset on ``user``.

    Delegates to ``get_missing_profile_fields`` from the profile prompts
    module, passing the four profile attributes of the user model.
    """
    from agents.prompts.profile_prompts import get_missing_profile_fields

    current_values = {
        "birth_year": user.birth_year,
        "birth_place": user.birth_place,
        "grew_up_place": user.grew_up_place,
        "occupation": user.occupation,
    }
    return get_missing_profile_fields(**current_values)
def _get_filled_profile_fields(user: UserModel) -> dict:
    """Collect the profile fields that already hold a value on ``user``.

    Args:
        user: Object exposing ``birth_year``, ``birth_place``,
            ``grew_up_place`` and ``occupation`` attributes.

    Returns:
        A dict keyed by field name (not by display name) containing only
        the truthy values; the birth year is converted to a string.
    """
    filled = {}
    if user.birth_year:
        filled["birth_year"] = str(user.birth_year)
    for field_name in ("birth_place", "grew_up_place", "occupation"):
        value = getattr(user, field_name)
        if value:
            filled[field_name] = value
    return filled
async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession):
    """Copy newly extracted profile values onto the user and persist them.

    A value is applied only when its key is present in ``extracted`` and the
    user's current value for that field is falsy, so existing data is never
    overwritten. The session is committed and the user refreshed only when
    at least one field was set.
    """
    updated_any = False
    for field_name in ("birth_year", "birth_place", "grew_up_place", "occupation"):
        if field_name not in extracted or getattr(user, field_name):
            continue
        setattr(user, field_name, extracted[field_name])
        updated_any = True

    if not updated_any:
        return

    await db.commit()
    await db.refresh(user)
async def process_user_message(
conversation_id: str,
user_message: str,
conversation: Conversation,
segment: Segment,
db: AsyncSession,
manager: ConnectionManager
manager: ConnectionManager,
user: UserModel = None,
) -> None:
"""
处理用户消息生成Agent回应异步版本
Args:
conversation_id: 对话ID
user_message: 用户消息文本
conversation: 对话对象
segment: 段落对象
db: 数据库会话
manager: 连接管理器
Returns:
更新后的对话阶段
支持资料收集模式和正式访谈模式
"""
import asyncio as _asyncio
agent = manager.conversation_agent
# --- 资料收集模式 ---
if user:
missing = _get_missing_profile_fields(user)
if missing:
try:
extracted = await agent.extract_profile_from_message(user_message, missing)
if extracted:
await _apply_extracted_profile(user, extracted, db)
remaining = _get_missing_profile_fields(user)
filled = _get_filled_profile_fields(user)
responses = await agent.generate_profile_followup(
conversation_id=conversation_id,
user_message=user_message,
missing_fields=remaining,
filled_fields=filled,
nickname=user.nickname or "",
)
segment.agent_response = "\n\n".join(responses)
await db.commit()
for i, response_text in enumerate(responses):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response_text, "index": i, "total": len(responses)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
if i < len(responses) - 1:
await _asyncio.sleep(0.5)
return
except Exception as e:
logger.error(f"资料收集处理失败: {e}", exc_info=True)
# --- 正式访谈模式 ---
state = await get_or_create_state(conversation.user_id, db)
if conversation.conversation_stage != state.current_stage:
conversation.conversation_stage = state.current_stage
await db.commit()
# 获取已聊话题(保留老逻辑用于提示)
stmt_segments = select(Segment).where(
Segment.conversation_id == conversation_id
).order_by(Segment.created_at)
result_segments = await db.execute(stmt_segments)
previous_segments = result_segments.scalars().all()
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
# 构建用户资料上下文
user_profile_context = ""
if user:
from agents.prompts.profile_prompts import format_user_profile_context
user_profile_context = format_user_profile_context(
birth_year=user.birth_year,
birth_place=user.birth_place,
grew_up_place=user.grew_up_place,
occupation=user.occupation,
)
try:
# 异步生成回应(可能是多条消息)
responses = await agent.generate_response_with_state(
conversation_id=conversation_id,
user_message=user_message,
memoir_state=state
memoir_state=state,
user_profile_context=user_profile_context,
)
# 更新段落的 Agent 回应(存储完整内容)
segment.agent_response = "\n\n".join(responses)
await db.commit()
# 发送 Agent 回应(支持多条消息)
for i, response_text in enumerate(responses):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
@@ -439,13 +545,11 @@ async def process_user_message(
"data": {"text": response_text, "index": i, "total": len(responses)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 多条消息之间稍作间隔,模拟打字效果
if i < len(responses) - 1:
await _asyncio.sleep(0.5)
except Exception as e:
logger.error(f"处理用户消息失败: {e}", exc_info=True)
# 只在连接仍然活跃时发送错误消息
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
@@ -455,8 +559,6 @@ async def process_user_message(
})
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
return
async def process_conversation_segments(conversation_id: str, db: AsyncSession):

View File

@@ -14,7 +14,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from database.database import SessionLocal
from database.models import Book, Chapter, Segment, MemoirState
from database.models import Book, Chapter, Segment, MemoirState, User
from services.llm_service import llm_service
from agents.state_schema import MemoirStateSchema, SlotData, default_state
from agents.prompts.memory_prompts import (
@@ -23,6 +23,7 @@ from agents.prompts.memory_prompts import (
get_state_extraction_prompt,
STAGE_TO_ORDER,
)
from agents.prompts.profile_prompts import format_user_profile_context
logger = logging.getLogger(__name__)
@@ -179,9 +180,21 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
logger.warning(f"未找到段落: {segment_ids}")
return {"status": "no_segments"}
# 获取用户状态
# 获取用户状态和资料
state = _get_or_create_state_sync(user_id, db)
llm = llm_service.get_llm()
user_obj = db.get(User, user_id)
user_profile = ""
user_birth_year = None
if user_obj:
user_birth_year = user_obj.birth_year
user_profile = format_user_profile_context(
birth_year=user_obj.birth_year,
birth_place=user_obj.birth_place,
grew_up_place=user_obj.grew_up_place,
occupation=user_obj.occupation,
)
# 按阶段分组处理
stage_to_segments: Dict[str, List[Segment]] = {}
@@ -257,7 +270,9 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
title_prompt = get_creative_title_prompt(
stage=stage,
emotion="neutral",
slots=slot_snippets
slots=slot_snippets,
user_profile=user_profile,
birth_year=user_birth_year,
)
title_response = llm.invoke(title_prompt)
title = title_response.content.strip().strip('"')
@@ -267,6 +282,8 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
slots=slot_snippets,
new_content=combined_text,
existing_content=existing_content,
user_profile=user_profile,
birth_year=user_birth_year,
)
narrative_response = llm.invoke(narrative_prompt)
new_narrative = narrative_response.content.strip()