diff --git a/.github/workflows/docker-build-deploy.yml b/.github/workflows/docker-build-deploy.yml index b6e44ee..ea9b681 100644 --- a/.github/workflows/docker-build-deploy.yml +++ b/.github/workflows/docker-build-deploy.yml @@ -7,7 +7,13 @@ on: paths: - 'api/**' - '.github/workflows/**' - workflow_dispatch: # 允许手动触发 + workflow_dispatch: + inputs: + branch: + description: '部署分支' + required: false + type: string + default: '' env: IMAGE_NAME: lifecho-api @@ -24,6 +30,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.branch || github.ref }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -77,6 +85,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.branch || github.ref }} - name: Set up SSH uses: webfactory/ssh-agent@v0.9.0 @@ -91,11 +101,12 @@ jobs: - name: Determine image tag id: image_tag run: | - if [ "${{ github.ref_name }}" == "main" ] || [ "${{ github.ref_name }}" == "master" ]; then + DEPLOY_BRANCH="${{ github.event.inputs.branch || github.ref_name }}" + echo "deploy_branch=$DEPLOY_BRANCH" >> $GITHUB_OUTPUT + if [ "$DEPLOY_BRANCH" == "main" ] || [ "$DEPLOY_BRANCH" == "master" ]; then echo "tag=latest" >> $GITHUB_OUTPUT else - # 将分支名中的斜杠替换为破折号,以符合 Docker 标签规范 - BRANCH_TAG=$(echo "${{ github.ref_name }}" | sed 's/\//-/g') + BRANCH_TAG=$(echo "$DEPLOY_BRANCH" | sed 's/\//-/g') echo "tag=$BRANCH_TAG" >> $GITHUB_OUTPUT fi @@ -214,6 +225,11 @@ jobs: "docker exec -i life-echo-postgres psql -U $DB_USER -d $DB_NAME" \ < api/migrations/add_chapter_is_active.sql + echo "添加用户基础资料字段..." + ssh -p $SSH_PORT $SSH_USER@$SSH_HOST \ + "docker exec -i life-echo-postgres psql -U $DB_USER -d $DB_NAME" \ + < api/migrations/add_user_profile_fields.sql + echo "数据库迁移完成" - name: Verify deployment diff --git a/api/agents/conversation_agent.py b/api/agents/conversation_agent.py index fd92b64..1d7ce55 100644 --- a/api/agents/conversation_agent.py +++ b/api/agents/conversation_agent.py @@ -1,7 +1,9 @@ """ 对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应 支持异步调用和 Redis 会话存储 +支持用户基础资料收集和时代背景融入 """ +import json import logging from typing import List, Optional, Dict, Any @@ -11,6 +13,13 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from services.llm_service import llm_service from services.redis_service import redis_service from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt +from .prompts.profile_prompts import ( + get_profile_greeting_prompt, + get_profile_extraction_prompt, + get_profile_followup_prompt, + format_user_profile_context, + get_missing_profile_fields, +) from .state_schema import MemoirStateSchema logger = logging.getLogger(__name__) @@ -102,6 +111,87 @@ class ConversationAgent: logger.error(f"生成回应失败: {e}") return f"抱歉,生成回应时出现错误: {str(e)}" + async def generate_profile_greeting( + self, + conversation_id: str, + missing_fields: List[str], + nickname: str = "", + ) -> List[str]: + """生成资料收集的开场白(首次对话时使用)""" + if not self.llm: + return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"] + + try: + prompt = get_profile_greeting_prompt(missing_fields, nickname) + history_messages = await self._get_history_messages(conversation_id) + history_string = self._format_history_string(history_messages) + + full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, 'content') else str(response) + + await self._save_message(conversation_id, "ai", response_text) + + messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + return messages[:3] if messages else [response_text] + except Exception as e: + logger.error(f"生成资料收集开场白失败: {e}") + return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"] + + async def extract_profile_from_message(self, user_message: str, missing_fields: List[str]) -> Dict[str, Any]: + """从用户消息中提取基础资料信息""" + if not self.llm or not missing_fields: + return {} + + try: + prompt = get_profile_extraction_prompt(user_message, missing_fields) + response = await self.llm.ainvoke(prompt) + content = response.content.strip() + parsed = json.loads(content) + result = {} + if "birth_year" in parsed and isinstance(parsed["birth_year"], int): + result["birth_year"] = parsed["birth_year"] + if "birth_place" in parsed and parsed["birth_place"]: + result["birth_place"] = str(parsed["birth_place"]) + if "grew_up_place" in parsed and parsed["grew_up_place"]: + result["grew_up_place"] = str(parsed["grew_up_place"]) + if "occupation" in parsed and parsed["occupation"]: + result["occupation"] = str(parsed["occupation"]) + return result + except (json.JSONDecodeError, Exception) as e: + logger.error(f"提取资料信息失败: {e}") + return {} + + async def generate_profile_followup( + self, + conversation_id: str, + user_message: str, + missing_fields: List[str], + filled_fields: Dict[str, str], + nickname: str = "", + ) -> List[str]: + """在资料收集过程中生成跟进回复""" + if not self.llm: + return ["谢谢!还能告诉我更多吗?"] + + try: + prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname) + history_messages = await self._get_history_messages(conversation_id) + history_string = self._format_history_string(history_messages) + + full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, 'content') else str(response) + + await self._save_message(conversation_id, "human", user_message) + await self._save_message(conversation_id, "ai", response_text) + + messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + return messages[:3] if messages else [response_text] + except Exception as e: + logger.error(f"生成资料跟进回复失败: {e}") + return ["谢谢分享!能再告诉我一些吗?"] + def _detect_user_stage(self, user_message: str) -> str: """ 通过关键词检测用户当前正在谈论的人生阶段。 @@ -126,7 +216,8 @@ class ConversationAgent: self, conversation_id: str, user_message: str, - memoir_state: MemoirStateSchema + memoir_state: MemoirStateSchema, + user_profile_context: str = "", ) -> List[str]: """ 基于共享状态异步生成引导式回复 @@ -135,6 +226,7 @@ class ConversationAgent: conversation_id: 对话 ID user_message: 用户消息 memoir_state: 共享状态 + user_profile_context: 用户基础资料上下文 Returns: Agent 回应文本列表(支持多条消息) @@ -150,18 +242,11 @@ class ConversationAgent: if value.snippet } - # 检测用户当前正在谈论的阶段 detected_user_stage = self._detect_user_stage(user_message) - # 从 Redis 获取对话历史,用于计算对话轮数 history_messages = await self._get_history_messages(conversation_id) - conversation_turn = len(history_messages) // 2 # 每轮包括一个用户消息和一个AI回复 - - # 计算同一话题的轮数(简单估算:基于已填充槽位的变化) - # 如果槽位数量没有增加,说明还在同一话题深入 + conversation_turn = len(history_messages) // 2 same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots) - - # 获取所有阶段的覆盖情况 all_stages_coverage = memoir_state.all_stages_coverage() system_prompt = get_guided_conversation_prompt( @@ -173,24 +258,19 @@ class ConversationAgent: same_topic_turns=same_topic_turns, all_stages_coverage=all_stages_coverage, detected_user_stage=detected_user_stage, + user_profile_context=user_profile_context, ) history_string = self._format_history_string(history_messages) - - # 构建完整 prompt full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" - # 异步调用 LLM response = await self.llm.ainvoke(full_prompt) response_text = response.content if hasattr(response, 'content') else str(response) - # 保存对话到 Redis await self._save_message(conversation_id, "human", user_message) await self._save_message(conversation_id, "ai", response_text) - # 支持多条消息,用 [SPLIT] 分隔 messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] - # 最多返回 3 条 return messages[:3] if messages else [response_text] except Exception as e: diff --git a/api/agents/prompts/__init__.py b/api/agents/prompts/__init__.py index 819c864..d5c04b2 100644 --- a/api/agents/prompts/__init__.py +++ b/api/agents/prompts/__init__.py @@ -19,6 +19,14 @@ from .memory_prompts import ( CHAPTER_ORDER, STAGE_TO_ORDER, ) +from .profile_prompts import ( + get_profile_greeting_prompt, + get_profile_extraction_prompt, + get_profile_followup_prompt, + format_user_profile_context, + get_missing_profile_fields, + PROFILE_FIELD_NAMES, +) __all__ = [ "ConversationStage", @@ -35,5 +43,11 @@ __all__ = [ "CHAPTER_CATEGORIES", "CHAPTER_ORDER", "STAGE_TO_ORDER", + "get_profile_greeting_prompt", + "get_profile_extraction_prompt", + "get_profile_followup_prompt", + "format_user_profile_context", + "get_missing_profile_fields", + "PROFILE_FIELD_NAMES", ] diff --git a/api/agents/prompts/conversation_prompts.py b/api/agents/prompts/conversation_prompts.py index 3f90e26..886f588 100644 --- a/api/agents/prompts/conversation_prompts.py +++ b/api/agents/prompts/conversation_prompts.py @@ -173,6 +173,73 @@ RESPONSE_STYLES = [ ] +def _build_era_context(current_stage: str, user_profile_context: str) -> str: + """ + 根据用户的人生阶段和出生年份,生成对应时代的历史/政治/文化背景提示。 + 让 agent 在对话中自然融入时代感。 + """ + if not user_profile_context: + return "" + + birth_year = None + birth_place = "" + for line in user_profile_context.split("\n"): + if "出生年份" in line: + try: + birth_year = int(line.split(":")[1].strip().replace("年", "")) + except (ValueError, IndexError): + pass + if "出生地" in line or "成长地" in line: + birth_place = line.split(":")[1].strip() if ":" in line else "" + + if not birth_year: + return "" + + stage_era_map = { + "childhood": (0, 12), + "education": (6, 22), + "career": (18, 50), + "family": (20, 50), + "belief": (30, 60), + } + + age_range = stage_era_map.get(current_stage, (0, 30)) + era_start = birth_year + age_range[0] + era_end = birth_year + age_range[1] + + era_events = [] + + decade_events = { + 1950: "新中国成立初期、土地改革、抗美援朝", + 1960: "大跃进、三年自然灾害、中苏关系变化", + 1970: "文化大革命、知青上山下乡、中美建交", + 1980: "改革开放、恢复高考、个体经济兴起、电视普及", + 1990: "社会主义市场经济、下海潮、香港回归、互联网初期", + 2000: "加入WTO、房地产兴起、手机普及、北京奥运", + 2010: "移动互联网爆发、微信时代、共享经济、双创浪潮", + 2020: "新冠疫情、直播经济、人工智能崛起", + } + + for decade, events in decade_events.items(): + if era_start <= decade + 9 and era_end >= decade: + era_events.append(f"{decade}年代:{events}") + + if not era_events: + return "" + + place_hint = f"(用户来自{birth_place})" if birth_place else "" + + return f""" +## 时代背景参考{place_hint} +用户在这个人生阶段大约经历了 {era_start}-{era_end} 年({age_range[0]}-{age_range[1]} 岁): +{";".join(era_events)} + +你可以在对话中自然地提及这些时代元素来丰富提问,例如: +- "那个年代好像正好赶上xxx,你们那边是什么情况?" +- "听说那时候xxx特别流行,你有印象吗?" +- 不要生硬地列举历史事件,而是像聊天一样自然带入""" + + def get_guided_conversation_prompt( current_stage: str, empty_slots: List[str], @@ -182,6 +249,7 @@ def get_guided_conversation_prompt( same_topic_turns: int = 0, all_stages_coverage: Dict[str, Dict] = None, detected_user_stage: str = "", + user_profile_context: str = "", ) -> str: """ 生成状态感知的对话提示词 @@ -195,6 +263,7 @@ def get_guided_conversation_prompt( same_topic_turns: 同一话题的轮数 all_stages_coverage: 所有阶段的覆盖情况 {stage: {total, filled, empty, ratio}} detected_user_stage: 检测到用户正在谈论的阶段(可能和 current_stage 不同) + user_profile_context: 用户基础资料上下文 """ stage_name_map = { "childhood": "童年时光", @@ -286,8 +355,16 @@ def get_guided_conversation_prompt( else: topic_desc = f"你们聊到了「{current_stage_name}」这个话题" - prompt = f"""你是「岁月知己」,用户的老朋友,正在和他/她聊人生故事。{topic_desc}。 + # --- 用户资料和时代背景 --- + profile_section = "" + if user_profile_context: + profile_section = f"\n## 用户基本信息\n{user_profile_context}\n" + active_stage = detected_user_stage if user_jumped and detected_user_stage else current_stage + era_context = _build_era_context(active_stage, user_profile_context) + + prompt = f"""你是「岁月知己」,用户的老朋友,正在和他/她聊人生故事。{topic_desc}。 +{profile_section} ## 已经聊到的内容({current_stage_name}) {filled_slots_str} @@ -296,7 +373,7 @@ def get_guided_conversation_prompt( ## 整体进度 {progress_str} - +{era_context} ## 用户刚才说 "{user_message}" @@ -309,6 +386,7 @@ def get_guided_conversation_prompt( 3. **保持自然**:不要每次都追问,有时候可以分享感受、表达好奇、或者轻松聊两句 4. **适时引导**:跟着用户的节奏聊了几轮后,如果有自然的时机,可以温和地引向还没聊到的人生阶段,但绝不要生硬 5. **追问要具体**:如果要追问,问具体的细节,比如"那时候是什么季节""身边有谁陪着你""当时心里什么感觉" +6. **融入时代感**:如果有时代背景信息,在聊天中自然地提及当时的社会环境、流行文化、历史事件,让对话更有代入感和共鸣 {dynamic_guidance}{uncovered_hint} ## 回复格式 @@ -331,6 +409,8 @@ def get_guided_conversation_prompt( - "那个年代的xxx确实是这样"(理解) - "所以后来怎么样了?"(好奇) - "对了,你刚才提到xxx,那个时候..."(换话题) +- "那会儿好像正赶上改革开放,你们那边变化大吗?"(时代融入) +- "80年代初的xxx,你还有印象吗?"(时代细节) 直接输出你要说的话(多条消息用 [SPLIT] 分隔):""" diff --git a/api/agents/prompts/memory_prompts.py b/api/agents/prompts/memory_prompts.py index 3672b9b..a2b7e90 100644 --- a/api/agents/prompts/memory_prompts.py +++ b/api/agents/prompts/memory_prompts.py @@ -2,6 +2,7 @@ 回忆录整理 Agent 提示词模板 """ import json +from typing import Optional # 章节分类映射 CHAPTER_CATEGORIES = { @@ -177,42 +178,89 @@ def get_state_extraction_prompt(user_message: str, current_stage: str, stage_slo """ -def get_creative_title_prompt(stage: str, emotion: str, slots: dict) -> str: - """生成有创意的章节标题""" +def _build_age_hint(stage: str, birth_year: Optional[int] = None) -> str: + """根据人生阶段和出生年份推算大致年龄区间""" + if not birth_year: + return "" + stage_age_ranges = { + "childhood": (0, 12), + "education": (6, 22), + "career": (18, 60), + "career_early": (18, 30), + "career_achievement": (25, 55), + "career_challenge": (20, 55), + "family": (20, 60), + "belief": (30, 70), + "beliefs": (30, 70), + "summary": (50, 80), + } + age_range = stage_age_ranges.get(stage) + if not age_range: + return "" + year_start = birth_year + age_range[0] + year_end = birth_year + age_range[1] + return f"大约 {year_start}-{year_end} 年({age_range[0]}-{age_range[1]} 岁)" + + +def get_creative_title_prompt( + stage: str, + emotion: str, + slots: dict, + user_profile: str = "", + birth_year: Optional[int] = None, +) -> str: + """生成有创意的章节标题,包含年龄/时间信息""" + age_hint = _build_age_hint(stage, birth_year) + profile_section = f"\n用户基本信息:\n{user_profile}" if user_profile else "" + time_section = f"\n时间参考:{age_hint}" if age_hint else "" + return f"""{get_system_prompt()} 请根据阶段和情绪生成 1 个有创意的章节标题。 阶段:{stage} 情绪:{emotion} -可用信息:{slots} +可用信息:{slots}{profile_section}{time_section} 要求: -1. 标题 12-18 字以内 +1. 标题格式:「时间标注 · 标题正文」 + - 时间标注用年龄或年代表示,如"6-12岁"、"1980年代"、"二十出头" + - 标题正文 12-18 字以内 2. 情绪 + 人生阶段 + 意象 3. 示例风格: - - 《那个夏天,我第一次离开家》 - - 《在陌生城市站稳脚跟》 - - 《不是所有选择都被理解》 - - 《慢下来,人生开始发声》 + - 《6-12岁 · 那条巷子尽头的蝉鸣》 + - 《18岁 · 第一次离开家的夏天》 + - 《25-35岁 · 在陌生城市站稳脚跟》 + - 《四十不惑 · 慢下来,人生开始发声》 + - 《1990年代 · 不是所有选择都被理解》 -只输出标题文字,不要加引号或其他内容。 +只输出标题文字,不要加引号或书名号。 """ -def get_narrative_prompt(stage: str, slots: dict, new_content: str, existing_content: str = "") -> str: +def get_narrative_prompt( + stage: str, + slots: dict, + new_content: str, + existing_content: str = "", + user_profile: str = "", + birth_year: Optional[int] = None, +) -> str: """将新对话改写为叙述(只输出新内容的改写,不重复已有内容)""" - # 只取已有内容的末尾作为衔接上下文 context_tail = "" if existing_content: context_tail = existing_content[-300:] if len(existing_content) > 300 else existing_content context_section = f"\n\n【衔接上下文(已有内容的末尾,仅供参考衔接,不要重复)】:\n{context_tail}" if context_tail else "" + profile_section = f"\n\n用户基本信息:\n{user_profile}" if user_profile else "" + age_hint = _build_age_hint(stage, birth_year) + time_section = f"\n时间参考:{age_hint}" if age_hint else "" + return f"""{get_system_prompt()} 请将以下新的对话内容改写为第一人称文学叙述。 阶段:{stage} -可用信息:{slots} +可用信息:{slots}{profile_section}{time_section} 新的对话内容: {new_content} @@ -225,6 +273,7 @@ def get_narrative_prompt(stage: str, slots: dict, new_content: str, existing_con 4. 如果有衔接上下文,确保新内容与之自然衔接(语气、时间线连贯) 5. 语气自然,有情绪 6. 在适合配图的地方插入图片占位符 +7. 如果有用户的基本信息(出生地、成长地等),在叙述中自然融入地域文化和时代背景 ## 图片占位符格式 在描述场景、人物、重要时刻的段落后,插入图片占位符,格式为: diff --git a/api/agents/prompts/profile_prompts.py b/api/agents/prompts/profile_prompts.py new file mode 100644 index 0000000..4ef38be --- /dev/null +++ b/api/agents/prompts/profile_prompts.py @@ -0,0 +1,163 @@ +""" +用户基础资料收集提示词 +""" +from typing import Dict, List, Optional + + +PROFILE_FIELD_NAMES = { + "birth_year": "出生年份", + "birth_place": "出生地", + "grew_up_place": "成长地", + "occupation": "职业", +} + + +def get_profile_greeting_prompt(missing_fields: List[str], nickname: str = "") -> str: + """生成初次见面、收集基础资料的引导提示词""" + missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES] + missing_str = "、".join(missing_names) + name_part = f",{nickname}" if nickname else "" + + return f"""你是「岁月知己」,一位温暖真诚的人生故事访谈者。你正在和用户初次见面{name_part}。 + +在正式聊人生故事之前,你需要先了解一些基本信息。还需要了解的信息有:{missing_str}。 + +## 你的任务 +用自然、亲切的方式,像老朋友聊天一样,向用户询问这些基础信息。 + +## 规则 +1. 不要一次问所有问题,每次只问 1-2 个 +2. 如果用户已经在对话中提到了某些信息,不要重复问 +3. 用口语化、亲切的方式提问 +4. 当所有信息都收集完后,自然过渡到人生故事访谈 + +## 提问示例 +- "你是哪一年出生的呀?" +- "你是在哪里出生的?小时候也是在那里长大的吗?" +- "你现在是做什么工作的呀?或者之前主要从事什么职业?" + +## 严格禁止 +- 禁止输出括号注释、思考过程 +- 禁止说"我需要收集信息"之类的机械话 +- 禁止一次列出所有问题 + +## 回复格式 +- 如果内容较多,可以用 [SPLIT] 分隔成多条消息 +- 像微信聊天一样自然 + +直接输出你要说的话:""" + + +def get_profile_extraction_prompt(user_message: str, missing_fields: List[str]) -> str: + """从用户回答中提取基础资料信息""" + missing_names = {f: PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES} + + return f"""请从用户的回答中提取基础资料信息。 + +用户的回答: +"{user_message}" + +需要提取的字段(只提取确实提到的): +{missing_names} + +请返回 JSON 格式,只包含确实提到的字段: +{{ + "birth_year": 1965, + "birth_place": "湖南长沙", + "grew_up_place": "湖南长沙", + "occupation": "教师" +}} + +规则: +1. birth_year 必须是整数(四位数年份),如"65年出生"应转为 1965 +2. 如果用户说"在老家长大"而之前提到了出生地,grew_up_place 可以和 birth_place 相同 +3. 只提取明确提到的信息,不要猜测 +4. 如果没有提取到任何信息,返回空对象 {{}} + +只返回 JSON,不要其他内容。""" + + +def get_profile_followup_prompt( + missing_fields: List[str], + filled_fields: Dict[str, str], + user_message: str, + nickname: str = "", +) -> str: + """在收集资料过程中的跟进提问""" + missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES] + missing_str = "、".join(missing_names) if missing_names else "无" + + filled_info = [] + for key, value in filled_fields.items(): + name = PROFILE_FIELD_NAMES.get(key, key) + filled_info.append(f"{name}: {value}") + filled_str = "\n".join(filled_info) if filled_info else "暂无" + + if not missing_names: + return f"""你是「岁月知己」。用户的基本信息已经收集完毕: +{filled_str} + +用户刚才说:"{user_message}" + +请对用户的回答做出温暖的回应,然后自然地过渡到人生故事的访谈。 +可以说类似"了解了!那我们现在开始聊聊你的人生故事吧"这样的话,然后问一个关于童年的问题作为开场。 + +回复格式:多条消息用 [SPLIT] 分隔。 +直接输出你要说的话:""" + + return f"""你是「岁月知己」,正在和用户聊天收集基本信息。 + +已知信息: +{filled_str} + +还需要了解:{missing_str} + +用户刚才说:"{user_message}" + +请先对用户说的内容做出自然回应,然后继续询问还未了解的信息(每次问 1-2 个)。 +语气要像朋友聊天一样自然亲切。 + +严格禁止: +- 禁止输出括号注释、思考过程 +- 禁止说"我注意到""我需要了解" + +回复格式:多条消息用 [SPLIT] 分隔。 +直接输出你要说的话:""" + + +def format_user_profile_context( + birth_year: Optional[int] = None, + birth_place: Optional[str] = None, + grew_up_place: Optional[str] = None, + occupation: Optional[str] = None, +) -> str: + """将用户基础信息格式化为上下文字符串,供其他 agent 使用""" + parts = [] + if birth_year: + parts.append(f"出生年份:{birth_year}年") + if birth_place: + parts.append(f"出生地:{birth_place}") + if grew_up_place: + parts.append(f"成长地:{grew_up_place}") + if occupation: + parts.append(f"职业:{occupation}") + return "\n".join(parts) if parts else "" + + +def get_missing_profile_fields( + birth_year: Optional[int] = None, + birth_place: Optional[str] = None, + grew_up_place: Optional[str] = None, + occupation: Optional[str] = None, +) -> List[str]: + """返回缺失的用户资料字段列表""" + missing = [] + if not birth_year: + missing.append("birth_year") + if not birth_place: + missing.append("birth_place") + if not grew_up_place: + missing.append("grew_up_place") + if not occupation: + missing.append("occupation") + return missing diff --git a/api/database/models.py b/api/database/models.py index 7659c9d..d16a467 100644 --- a/api/database/models.py +++ b/api/database/models.py @@ -30,6 +30,11 @@ class User(Base): subscription_expires_at = Column(DateTime(timezone=True), nullable=True) # 订阅到期时间 created_at = Column(DateTime(timezone=True), default=utc_now) + birth_year = Column(Integer, nullable=True) + birth_place = Column(String, nullable=True) + grew_up_place = Column(String, nullable=True) + occupation = Column(String, nullable=True) + # Relationships conversations = relationship("Conversation", back_populates="user") chapters = relationship("Chapter", back_populates="user") diff --git a/api/migrations/add_user_profile_fields.sql b/api/migrations/add_user_profile_fields.sql new file mode 100644 index 0000000..85e7274 --- /dev/null +++ b/api/migrations/add_user_profile_fields.sql @@ -0,0 +1,5 @@ +-- 添加用户基础资料字段(出生年份、出生地、成长地、职业) +ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_year INTEGER; +ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_place VARCHAR; +ALTER TABLE users ADD COLUMN IF NOT EXISTS grew_up_place VARCHAR; +ALTER TABLE users ADD COLUMN IF NOT EXISTS occupation VARCHAR; diff --git a/api/routers/user.py b/api/routers/user.py index a1b9086..36ce8bb 100644 --- a/api/routers/user.py +++ b/api/routers/user.py @@ -28,6 +28,18 @@ class UserProfileResponse(BaseModel): avatar_url: Optional[str] subscription_type: str created_at: str + birth_year: Optional[int] = None + birth_place: Optional[str] = None + grew_up_place: Optional[str] = None + occupation: Optional[str] = None + + +class UpdateUserProfileRequest(BaseModel): + """更新用户基础资料请求""" + birth_year: Optional[int] = None + birth_place: Optional[str] = None + grew_up_place: Optional[str] = None + occupation: Optional[str] = None class TestSubscriptionRequest(BaseModel): @@ -59,7 +71,43 @@ async def get_user_profile( nickname=current_user.nickname, avatar_url=current_user.avatar_url, subscription_type=current_user.subscription_type, - created_at=current_user.created_at.isoformat() + created_at=current_user.created_at.isoformat(), + birth_year=current_user.birth_year, + birth_place=current_user.birth_place, + grew_up_place=current_user.grew_up_place, + occupation=current_user.occupation, + ) + + +@router.put("/profile", response_model=UserProfileResponse) +async def update_user_profile( + body: UpdateUserProfileRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_async_db), +): + """更新用户基础资料(出生年份、出生地、成长地、职业)""" + if body.birth_year is not None: + current_user.birth_year = body.birth_year + if body.birth_place is not None: + current_user.birth_place = body.birth_place + if body.grew_up_place is not None: + current_user.grew_up_place = body.grew_up_place + if body.occupation is not None: + current_user.occupation = body.occupation + await db.commit() + await db.refresh(current_user) + return UserProfileResponse( + id=current_user.id, + phone=current_user.phone, + email=current_user.email, + nickname=current_user.nickname, + avatar_url=current_user.avatar_url, + subscription_type=current_user.subscription_type, + created_at=current_user.created_at.isoformat(), + birth_year=current_user.birth_year, + birth_place=current_user.birth_place, + grew_up_place=current_user.grew_up_place, + occupation=current_user.occupation, ) diff --git a/api/routers/websocket.py b/api/routers/websocket.py index 27b9a78..ee0c403 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -161,6 +161,28 @@ async def websocket_endpoint( return + # 首次连接时检查资料完整性,发送资料收集开场白 + missing_profile = _get_missing_profile_fields(user) + if missing_profile: + try: + greetings = await manager.conversation_agent.generate_profile_greeting( + conversation_id=conversation_id, + missing_fields=missing_profile, + nickname=user.nickname or "", + ) + import asyncio as _asyncio_greet + for i, text in enumerate(greetings): + await manager.send_message(conversation_id, { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": {"text": text, "index": i, "total": len(greetings)}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + if i < len(greetings) - 1: + await _asyncio_greet.sleep(0.5) + except Exception as e: + logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) + # 主循环:处理消息 while True: try: @@ -206,7 +228,8 @@ async def websocket_endpoint( conversation=conversation, segment=segment, db=db, - manager=manager + manager=manager, + user=user, ) elif msg_type == MessageType.AUDIO_MESSAGE: @@ -267,7 +290,8 @@ async def websocket_endpoint( conversation=conversation, segment=segment, db=db, - manager=manager + manager=manager, + user=user, ) else: # 转写失败,发送错误消息 @@ -380,58 +404,140 @@ async def websocket_endpoint( await manager.disconnect(conversation_id) +def _get_missing_profile_fields(user: UserModel) -> list: + """检查用户缺失的资料字段""" + from agents.prompts.profile_prompts import get_missing_profile_fields + return get_missing_profile_fields( + birth_year=user.birth_year, + birth_place=user.birth_place, + grew_up_place=user.grew_up_place, + occupation=user.occupation, + ) + + +def _get_filled_profile_fields(user: UserModel) -> dict: + """获取用户已有的资料字段(中文展示)""" + from agents.prompts.profile_prompts import PROFILE_FIELD_NAMES + filled = {} + if user.birth_year: + filled["birth_year"] = str(user.birth_year) + if user.birth_place: + filled["birth_place"] = user.birth_place + if user.grew_up_place: + filled["grew_up_place"] = user.grew_up_place + if user.occupation: + filled["occupation"] = user.occupation + return filled + + +async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession): + """将提取到的资料信息保存到用户模型""" + changed = False + if "birth_year" in extracted and not user.birth_year: + user.birth_year = extracted["birth_year"] + changed = True + if "birth_place" in extracted and not user.birth_place: + user.birth_place = extracted["birth_place"] + changed = True + if "grew_up_place" in extracted and not user.grew_up_place: + user.grew_up_place = extracted["grew_up_place"] + changed = True + if "occupation" in extracted and not user.occupation: + user.occupation = extracted["occupation"] + changed = True + if changed: + await db.commit() + await db.refresh(user) + + async def process_user_message( conversation_id: str, user_message: str, conversation: Conversation, segment: Segment, db: AsyncSession, - manager: ConnectionManager + manager: ConnectionManager, + user: UserModel = None, ) -> None: """ 处理用户消息,生成Agent回应(异步版本) - - Args: - conversation_id: 对话ID - user_message: 用户消息文本 - conversation: 对话对象 - segment: 段落对象 - db: 数据库会话 - manager: 连接管理器 - - Returns: - 更新后的对话阶段 + 支持资料收集模式和正式访谈模式 """ import asyncio as _asyncio - + agent = manager.conversation_agent + + # --- 资料收集模式 --- + if user: + missing = _get_missing_profile_fields(user) + if missing: + try: + extracted = await agent.extract_profile_from_message(user_message, missing) + if extracted: + await _apply_extracted_profile(user, extracted, db) + + remaining = _get_missing_profile_fields(user) + filled = _get_filled_profile_fields(user) + responses = await agent.generate_profile_followup( + conversation_id=conversation_id, + user_message=user_message, + missing_fields=remaining, + filled_fields=filled, + nickname=user.nickname or "", + ) + + segment.agent_response = "\n\n".join(responses) + await db.commit() + + for i, response_text in enumerate(responses): + await manager.send_message(conversation_id, { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": {"text": response_text, "index": i, "total": len(responses)}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + if i < len(responses) - 1: + await _asyncio.sleep(0.5) + return + except Exception as e: + logger.error(f"资料收集处理失败: {e}", exc_info=True) + + # --- 正式访谈模式 --- state = await get_or_create_state(conversation.user_id, db) if conversation.conversation_stage != state.current_stage: conversation.conversation_stage = state.current_stage await db.commit() - # 获取已聊话题(保留老逻辑用于提示) stmt_segments = select(Segment).where( Segment.conversation_id == conversation_id ).order_by(Segment.created_at) result_segments = await db.execute(stmt_segments) previous_segments = result_segments.scalars().all() covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] - + + # 构建用户资料上下文 + user_profile_context = "" + if user: + from agents.prompts.profile_prompts import format_user_profile_context + user_profile_context = format_user_profile_context( + birth_year=user.birth_year, + birth_place=user.birth_place, + grew_up_place=user.grew_up_place, + occupation=user.occupation, + ) + try: - # 异步生成回应(可能是多条消息) responses = await agent.generate_response_with_state( conversation_id=conversation_id, user_message=user_message, - memoir_state=state + memoir_state=state, + user_profile_context=user_profile_context, ) - - # 更新段落的 Agent 回应(存储完整内容) + segment.agent_response = "\n\n".join(responses) await db.commit() - - # 发送 Agent 回应(支持多条消息) + for i, response_text in enumerate(responses): await manager.send_message(conversation_id, { "type": MessageType.AGENT_RESPONSE, @@ -439,13 +545,11 @@ async def process_user_message( "data": {"text": response_text, "index": i, "total": len(responses)}, "timestamp": datetime.now(timezone.utc).isoformat() }) - # 多条消息之间稍作间隔,模拟打字效果 if i < len(responses) - 1: await _asyncio.sleep(0.5) - + except Exception as e: logger.error(f"处理用户消息失败: {e}", exc_info=True) - # 只在连接仍然活跃时发送错误消息 if conversation_id in manager.active_connections: try: await manager.send_message(conversation_id, { @@ -455,8 +559,6 @@ async def process_user_message( }) except Exception as send_error: logger.warning(f"发送错误消息失败: {send_error}") - - return async def process_conversation_segments(conversation_id: str, db: AsyncSession): diff --git a/api/tasks/memoir_tasks.py b/api/tasks/memoir_tasks.py index 2db5f4b..59a5237 100644 --- a/api/tasks/memoir_tasks.py +++ b/api/tasks/memoir_tasks.py @@ -14,7 +14,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from database.database import SessionLocal -from database.models import Book, Chapter, Segment, MemoirState +from database.models import Book, Chapter, Segment, MemoirState, User from services.llm_service import llm_service from agents.state_schema import MemoirStateSchema, SlotData, default_state from agents.prompts.memory_prompts import ( @@ -23,6 +23,7 @@ from agents.prompts.memory_prompts import ( get_state_extraction_prompt, STAGE_TO_ORDER, ) +from agents.prompts.profile_prompts import format_user_profile_context logger = logging.getLogger(__name__) @@ -179,9 +180,21 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): logger.warning(f"未找到段落: {segment_ids}") return {"status": "no_segments"} - # 获取用户状态 + # 获取用户状态和资料 state = _get_or_create_state_sync(user_id, db) llm = llm_service.get_llm() + + user_obj = db.get(User, user_id) + user_profile = "" + user_birth_year = None + if user_obj: + user_birth_year = user_obj.birth_year + user_profile = format_user_profile_context( + birth_year=user_obj.birth_year, + birth_place=user_obj.birth_place, + grew_up_place=user_obj.grew_up_place, + occupation=user_obj.occupation, + ) # 按阶段分组处理 stage_to_segments: Dict[str, List[Segment]] = {} @@ -257,7 +270,9 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): title_prompt = get_creative_title_prompt( stage=stage, emotion="neutral", - slots=slot_snippets + slots=slot_snippets, + user_profile=user_profile, + birth_year=user_birth_year, ) title_response = llm.invoke(title_prompt) title = title_response.content.strip().strip('"') @@ -267,6 +282,8 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): slots=slot_snippets, new_content=combined_text, existing_content=existing_content, + user_profile=user_profile, + birth_year=user_birth_year, ) narrative_response = llm.invoke(narrative_prompt) new_narrative = narrative_response.content.strip()