"""
Re-process a user's historical conversations into memoir chapters
(remote preview first, write only after confirmation).

Usage:
    cd api
    # Step 1: preview (reads the remote DB only, generates the new chapters
    # locally and writes a comparison Markdown report)
    python -m scripts.reprocess_user_memoir preview --phone 13800138000

    # Step 2: after reviewing the preview, write it to the remote DB
    python -m scripts.reprocess_user_memoir apply --phone 13800138000

Flow:
    preview:
        1. Open an SSH tunnel to the remote PostgreSQL.
        2. Read the user's existing chapters and all historical conversation
           segments.
        3. Call the LLM locally to generate new chapters (nothing is written
           to the remote DB).
        4. Print a Markdown comparison table and save the result to a JSON
           file.
    apply:
        1. Load the JSON file saved by the last preview run.
        2. Open an SSH tunnel to the remote PostgreSQL.
        3. Mark the old chapters is_active=False and insert the new chapters.

The SSH settings below default to the original values and can be overridden
with the REPROCESS_SSH_HOST / REPROCESS_SSH_PORT / REPROCESS_SSH_USER /
REPROCESS_SSH_KEY_PATH / REPROCESS_REMOTE_PG_HOST environment variables.
"""
import argparse
import json
import os
import signal
import socket
import subprocess
import sys
import time
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Optional

# Make sure the api/ directory is on sys.path before importing `app.*`.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker, Session

# Application settings (LLM keys etc.) are loaded centrally by
# app.core.config.settings when these modules are imported.
from app.core.db import Base
from app.features.conversation.models import Conversation, Segment
from app.features.memoir.models import Book, Chapter, ChapterSection, MemoirState
from app.features.user.models import User
from app.core.dependencies import get_llm_provider
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.agents.memoir.prompts import (
    get_creative_title_prompt,
    get_narrative_prompt,
    get_state_extraction_prompt,
    inject_image_placeholder_template,
    STAGE_TO_ORDER,
)
from app.features.memoir.memoir_images.json_payload import extract_json_payload
from app.features.memoir.memoir_images.parser import split_narrative_to_sections
from app.core.logging import get_logger, setup_logging

setup_logging()
logger = get_logger(__name__)

# ── SSH / DB settings ─────────────────────────────────────────
# Defaults keep the original hard-coded values. Prefer setting the
# environment variables over editing host/credentials in source.
SSH_HOST = os.environ.get("REPROCESS_SSH_HOST", "1.15.29.57")
SSH_PORT = int(os.environ.get("REPROCESS_SSH_PORT", "22"))
SSH_USER = os.environ.get("REPROCESS_SSH_USER", "root")
SSH_KEY_PATH = os.environ.get(
    "REPROCESS_SSH_KEY_PATH",
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "..", "certs", "key.crt",
    ),
)
# Forward target as resolved on the SSH server (ssh -L semantics), so
# 127.0.0.1 means the SSH host itself.
REMOTE_PG_HOST = os.environ.get("REPROCESS_REMOTE_PG_HOST", "127.0.0.1")
REMOTE_PG_PORT = int(os.environ.get("REPROCESS_REMOTE_PG_PORT", "5432"))
PG_USER = os.environ.get("REPROCESS_PG_USER", "postgres")
PG_PASSWORD = os.environ.get("REPROCESS_PG_PASSWORD", "postgres")
PG_DATABASE = os.environ.get("REPROCESS_PG_DATABASE", "life_echo")

# preview writes preview_<phone>.json / .md here.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")

# ── Keyword-based stage detection ─────────────────────────────
# Checked in dict order; the first stage with a keyword hit wins.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}


def _detect_stage(text: str, fallback: str) -> str:
    """Return the first stage whose keywords occur in *text*, else *fallback*."""
    msg = text.lower()
    for stage, keywords in STAGE_KEYWORDS.items():
        if any(w in msg for w in keywords):
            return stage
    return fallback


# ── SSH tunnel + DB session ───────────────────────────────────
def _port_accepts(port: int, timeout: float = 1.0) -> bool:
    """Return True if something on 127.0.0.1:*port* accepts a TCP connection."""
    try:
        with socket.create_connection(("127.0.0.1", port), timeout=timeout):
            return True
    except OSError:
        return False


class SshTunnel:
    """Port forward implemented with an ``ssh -L`` subprocess (no paramiko needed).

    ``start()`` launches ssh and blocks until the local port accepts
    connections; ``stop()`` terminates the ssh child.
    """

    def __init__(self, local_port: int = 15432):
        self.local_port = local_port
        self._proc: Optional[subprocess.Popen] = None

    def start(self):
        """Launch ssh and wait up to ~15 s for the forwarded local port.

        Raises:
            RuntimeError: the local port is already in use, ssh exited early,
                or the port never became reachable. The ssh child is always
                cleaned up before raising.
        """
        # A listener already on this port (e.g. a stale tunnel to another
        # host) would satisfy the readiness probe below and silently receive
        # our DB traffic, so refuse up front.
        if _port_accepts(self.local_port):
            raise RuntimeError(
                f"本地端口 {self.local_port} 已被占用, 请先关闭占用该端口的进程(可能是残留的 SSH 隧道)"
            )

        key_path = os.path.normpath(SSH_KEY_PATH)
        cmd = [
            "ssh", "-N",
            "-L", f"{self.local_port}:{REMOTE_PG_HOST}:{REMOTE_PG_PORT}",
            "-i", key_path,
            "-p", str(SSH_PORT),
            # NOTE(security): the host key is not verified (MITM exposure);
            # kept for compatibility with existing setups.
            "-o", "StrictHostKeyChecking=no",
            # Exit instead of idling if the -L bind fails.
            "-o", "ExitOnForwardFailure=yes",
            # Never prompt interactively; fail instead.
            "-o", "BatchMode=yes",
            f"{SSH_USER}@{SSH_HOST}",
        ]
        logger.info(f"SSH 隧道: {SSH_USER}@{SSH_HOST}:{SSH_PORT} -> 127.0.0.1:{self.local_port}, key={key_path}")
        self._proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)

        # Probe every 0.5 s, 30 times (~15 s).
        for attempt in range(30):
            if self._proc.poll() is not None:
                err = self._terminate()
                raise RuntimeError(f"SSH 隧道进程已退出: {err}")
            if _port_accepts(self.local_port):
                logger.info(f"SSH 隧道已建立, 本地端口: {self.local_port} (耗时 {attempt * 0.5:.1f}s)")
                return
            time.sleep(0.5)

        # Timed out: kill ssh so it does not linger holding the port.
        err = self._terminate()
        raise RuntimeError(f"SSH 隧道端口 {self.local_port} 超时未就绪: {err}")

    def _terminate(self) -> str:
        """Stop the ssh child (SIGTERM, then SIGKILL after 5 s) and return its stderr."""
        proc, self._proc = self._proc, None
        if proc is None:
            return ""
        if proc.poll() is None:
            proc.send_signal(signal.SIGTERM)
        try:
            _, err = proc.communicate(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            _, err = proc.communicate()
        return (err or b"").decode(errors="replace")

    def stop(self):
        """Close the tunnel; safe to call more than once."""
        was_running = self._proc is not None and self._proc.poll() is None
        self._terminate()
        if was_running:
            logger.info("SSH 隧道已关闭")

    @property
    def local_bind_port(self) -> int:
        """Local port forwarded to the remote PostgreSQL."""
        return self.local_port


def open_ssh_tunnel() -> SshTunnel:
    """Start a tunnel on the default local port; the caller must ``stop()`` it."""
    tunnel = SshTunnel()
    tunnel.start()
    return tunnel


def make_session(tunnel: SshTunnel) -> "Session":
    """Open a SQLAlchemy session to the remote database through *tunnel*.

    ``NullPool`` makes ``Session.close()`` actually close the DBAPI
    connection, so no pooled connection outlives the tunnel. ``URL.create``
    escapes credentials that contain characters such as ``@``, ``:`` or ``/``.
    """
    from sqlalchemy.engine import URL
    from sqlalchemy.pool import NullPool

    url = URL.create(
        "postgresql",
        username=PG_USER,
        password=PG_PASSWORD,
        host="127.0.0.1",
        port=tunnel.local_bind_port,
        database=PG_DATABASE,
    )
    engine = create_engine(url, poolclass=NullPool)
    return sessionmaker(bind=engine)()


# ── Data structures for generated results ─────────────────────
@dataclass
class GeneratedChapter:
    """One regenerated chapter (one per detected stage), held in memory."""

    category: str
    title: str
    content: str
    order_index: int
    source_segment_ids: List[str] = field(default_factory=list)


@dataclass
class PreviewResult:
    """Everything a preview run produces; serialized to preview_<phone>.json."""

    user_id: str
    phone: str
    nickname: str
    generated_at: str
    # Dicts with keys: category, title, content_len, content_preview.
    old_chapters: List[dict] = field(default_factory=list)
    # Same keys plus content, order_index, source_segment_ids.
    new_chapters: List[dict] = field(default_factory=list)
    # Every segment id read during preview (blank segments included), i.e.
    # exactly the set of segments this preview covered.
    segment_ids: List[str] = field(default_factory=list)


# ── Core: generate chapters locally ───────────────────────────
def extract_slots_with_llm(
    llm,
    text: str,
    current_stage: str,
    stage_slots: dict,
    fallback_stage: Optional[str] = None,
):
    """Ask the LLM for the segment's stage and slot snippets.

    Args:
        llm: chat model exposing ``bind`` and ``invoke``.
        text: segment transcript.
        current_stage: the conversation's current stage (prompt input).
        stage_slots: slots already known for the candidate stage (prompt input).
        fallback_stage: stage returned when the LLM fails or gives no usable
            stage; defaults to *current_stage*.

    Returns:
        ``(stage, slots)``; *slots* is always a dict. Never raises — failures
        are logged and yield ``(fallback, {})``.
    """
    default_stage = fallback_stage or current_stage
    try:
        prompt = get_state_extraction_prompt(
            user_message=text,
            current_stage=current_stage,
            stage_slots=stage_slots,
        )
        json_llm = llm.bind(
            model_kwargs={"response_format": {"type": "json_object"}},
            max_tokens=1024,
        )
        response = json_llm.invoke(prompt)
        parsed = json.loads(extract_json_payload(response.content.strip()))
    except Exception as e:
        logger.warning(f"LLM slot 提取失败: {e}")
        return default_stage, {}

    if not isinstance(parsed, dict):
        logger.warning(f"LLM slot 提取失败: 返回的不是 JSON 对象 ({type(parsed).__name__})")
        return default_stage, {}

    # A null/blank stage would otherwise become a chapter category.
    stage = parsed.get("detected_stage")
    stage = stage.strip() if isinstance(stage, str) else ""
    if not stage:
        stage = default_stage

    slots = parsed.get("slots") or {}
    if not isinstance(slots, dict):
        slots = {}
    return stage, slots


def _assign_stages(segments: list, llm, skip_llm_slots: bool):
    """Group non-blank segments by life stage using an in-memory MemoirState.

    Keyword matching picks a stage, refined by the LLM unless
    *skip_llm_slots*. The chosen stage becomes the fallback for the next
    segment, so keyword-less follow-ups stay in the current stage.

    Returns:
        ``(state, stage_to_segments)`` — the state with any extracted slots,
        and an insertion-ordered ``{stage: [(segment_id, text), ...]}``.
    """
    state = default_state()
    stage_to_segments: Dict[str, list] = {}
    for idx, (seg_id, text) in enumerate(segments, 1):
        if not text or not text.strip():
            continue

        keyword_stage = _detect_stage(text, state.current_stage)
        stage = keyword_stage
        if not skip_llm_slots:
            try:
                stage, extracted_slots = extract_slots_with_llm(
                    llm,
                    text,
                    state.current_stage,
                    state.slots.get(keyword_stage, {}),
                    fallback_stage=keyword_stage,
                )
                for slot_name, snippet in extracted_slots.items():
                    if not snippet:
                        # Don't let an empty answer overwrite a slot filled earlier.
                        continue
                    stage_slots = state.slots.get(stage, {})
                    stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=[seg_id])
                    state.slots[stage] = stage_slots
            except Exception as e:
                logger.warning(f"段落 {idx} slot 提取失败: {e}")

        # Advance on both paths so the keyword-only mode is sticky too.
        state.current_stage = stage
        stage_to_segments.setdefault(stage, []).append((seg_id, text))

        if idx % 20 == 0:
            logger.info(f"阶段检测进度: {idx}/{len(segments)}")
    return state, stage_to_segments


def _generate_title(llm, stage: str, slot_snippets: dict) -> str:
    """Ask the LLM for a chapter title; fall back to "<stage> 回忆" on failure or empty reply."""
    default_title = f"{stage} 回忆"
    try:
        response = llm.invoke(
            get_creative_title_prompt(stage=stage, emotion="neutral", slots=slot_snippets)
        )
        title = response.content.strip().strip('"').strip()
    except Exception as e:
        logger.warning(f"[{stage}] 标题生成失败: {e}")
        return default_title
    if not title:
        return default_title
    logger.info(f"[{stage}] 生成标题: {title}")
    return title


def _narrate_batch(llm, stage: str, slot_snippets: dict, combined_text: str, existing_content: str) -> str:
    """Return *existing_content* extended with a narrative for *combined_text*.

    The raw transcript is used when the LLM call fails or returns nothing, so
    a batch's content is never dropped.
    """
    try:
        response = llm.invoke(
            get_narrative_prompt(
                stage=stage,
                slots=slot_snippets,
                new_content=combined_text,
                existing_content=existing_content,
            )
        )
        new_narrative = (response.content or "").strip()
        if not new_narrative:
            logger.warning(f"[{stage}] LLM 返回空内容, 使用原文")
    except Exception as e:
        logger.warning(f"[{stage}] LLM 生成失败: {e}")
        new_narrative = ""

    if not new_narrative:
        new_narrative = combined_text
    if existing_content:
        return f"{existing_content}\n\n{new_narrative}"
    return new_narrative


def _build_stage_chapter(llm, stage: str, seg_list: list, slot_snippets: dict, batch_size: int) -> GeneratedChapter:
    """Turn one stage's segments into a chapter, *batch_size* segments per narrative call."""
    title = _generate_title(llm, stage, slot_snippets)
    content = ""
    source_ids: List[str] = []
    total_batches = (len(seg_list) + batch_size - 1) // batch_size

    for batch_num, start in enumerate(range(0, len(seg_list), batch_size), 1):
        batch = seg_list[start : start + batch_size]
        logger.info(f"[{stage}] 处理第 {batch_num}/{total_batches} 批 ({len(batch)} 条)")
        source_ids.extend(seg_id for seg_id, _ in batch)
        combined_text = "\n\n".join(text for _, text in batch)
        content = _narrate_batch(llm, stage, slot_snippets, combined_text, content)
        logger.info(f"[{stage}] 批次 {batch_num} 完成, 累计长度: {len(content)} 字")
        if start + batch_size < len(seg_list):
            time.sleep(1)  # presumably to ease LLM rate limits between calls

    return GeneratedChapter(
        category=stage,
        title=title,
        # Before saving: placeholder positions are regex-matched and a fixed
        # template is attached.
        content=inject_image_placeholder_template(content),
        order_index=STAGE_TO_ORDER.get(stage, 999),
        source_segment_ids=source_ids,
    )


def generate_chapters_in_memory(
    segments: list,  # list of (id, transcript_text)
    llm,
    batch_size: int,
    skip_llm_slots: bool,
) -> List[GeneratedChapter]:
    """Generate one chapter per detected stage entirely in memory (no DB writes).

    Raises:
        ValueError: *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    state, stage_to_segments = _assign_stages(segments, llm, skip_llm_slots)
    for stage, segs in stage_to_segments.items():
        logger.info(f"阶段 [{stage}]: {len(segs)} 条段落")

    results: List[GeneratedChapter] = []
    for stage, seg_list in stage_to_segments.items():
        slot_snippets = {
            key: value.snippet
            for key, value in (state.slots.get(stage, {}) or {}).items()
            if value.snippet
        }
        results.append(_build_stage_chapter(llm, stage, seg_list, slot_snippets, batch_size))
    return results


# ── preview command ───────────────────────────────────────────
def _chapter_text(chapter) -> str:
    """Join a chapter's non-blank section texts in ``order_index`` order."""
    sections = getattr(chapter, "sections", None) or []
    parts = ((s.content or "").strip() for s in sorted(sections, key=lambda s: s.order_index))
    return "\n\n".join(part for part in parts if part)


def _preview_text(content: str, limit: int = 200) -> str:
    """First *limit* characters of *content*, with "…" appended when truncated."""
    return (content[:limit] + "…") if len(content) > limit else content


def _load_remote_snapshot(phone: str):
    """Read the user, their active chapters and all their segments over SSH.

    Returns ``(user_id, nickname, old_chapter_data, seg_tuples)`` with ids as
    ``str`` so the result is JSON-serializable, or ``None`` when the user has
    no segments. Exits with status 1 if no user has *phone*. The session and
    tunnel are closed before returning.
    """
    from sqlalchemy.orm import joinedload

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            user = db.execute(select(User).where(User.phone == phone)).scalar_one_or_none()
            if not user:
                logger.error(f"未找到手机号 {phone} 的用户")
                sys.exit(1)
            nickname = user.nickname
            logger.info(f"用户: {nickname} (id={user.id})")

            # Chapter text lives in its sections; load them eagerly.
            old_chapters = (
                db.execute(
                    select(Chapter)
                    .where(Chapter.user_id == user.id, Chapter.is_active == True)  # noqa: E712
                    .options(joinedload(Chapter.sections))
                    .order_by(Chapter.order_index)
                )
                .unique()
                .scalars()
                .all()
            )
            old_chapter_data = []
            for ch in old_chapters:
                content = _chapter_text(ch)
                old_chapter_data.append({
                    "category": ch.category,
                    "title": ch.title,
                    "content_len": len(content),
                    "content_preview": _preview_text(content),
                })
            logger.info(f"现有章节: {len(old_chapters)} 个")

            segments_raw = db.execute(
                select(Segment.id, Segment.transcript_text)
                .join(Conversation, Segment.conversation_id == Conversation.id)
                .where(Conversation.user_id == user.id)
                .order_by(Segment.created_at.asc())
            ).all()
            logger.info(f"历史段落: {len(segments_raw)} 条")
            user_id = str(user.id)
        finally:
            db.close()
    finally:
        tunnel.stop()

    if not segments_raw:
        logger.warning("没有对话段落,无需处理")
        return None
    seg_tuples = [(str(seg_id), text) for seg_id, text in segments_raw]
    return user_id, nickname, old_chapter_data, seg_tuples


def cmd_preview(phone: str, batch_size: int, skip_llm_slots: bool):
    """Generate new chapters for *phone* locally and save a comparison report.

    The remote DB is only read (through an SSH tunnel that is closed before
    any LLM work). Writes ``preview_<phone>.json`` and ``preview_<phone>.md``
    under OUTPUT_DIR and prints the Markdown. Exits with status 1 on an
    invalid batch size, a missing LLM, or an unknown phone number.
    """
    # Validate before the SSH round-trip rather than after it.
    if batch_size < 1:
        logger.error(f"--batch-size 必须 >= 1, 当前为 {batch_size}")
        sys.exit(1)

    llm = getattr(get_llm_provider(), "langchain_llm", None)
    if not llm:
        logger.error("LLM 未配置,请检查 .env 中的 DEEPSEEK_API_KEY")
        sys.exit(1)
    logger.info("LLM 就绪")

    snapshot = _load_remote_snapshot(phone)
    if snapshot is None:
        return
    user_id, nickname, old_chapter_data, seg_tuples = snapshot

    # Generate the new chapters locally (no DB needed).
    new_chapters = generate_chapters_in_memory(seg_tuples, llm, batch_size, skip_llm_slots)
    new_chapter_data = [
        {
            "category": ch.category,
            "title": ch.title,
            "content_len": len(ch.content),
            "content_preview": _preview_text(ch.content),
            "content": ch.content,
            "order_index": ch.order_index,
            "source_segment_ids": ch.source_segment_ids,
        }
        for ch in new_chapters
    ]

    result = PreviewResult(
        user_id=user_id,
        phone=phone,
        nickname=nickname,
        generated_at=datetime.now(timezone.utc).isoformat(),
        old_chapters=old_chapter_data,
        new_chapters=new_chapter_data,
        segment_ids=[seg_id for seg_id, _ in seg_tuples],
    )

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(asdict(result), f, ensure_ascii=False, indent=2)
    logger.info(f"预览结果已保存: {json_path}")

    md_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.md")
    markdown = _build_comparison_markdown(result)
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(markdown)
    logger.info(f"对比 Markdown 已保存: {md_path}")

    # Also echo to the terminal.
    print("\n" + markdown)


def _md_cell(value) -> str:
    """Make *value* safe for one Markdown table cell (escape pipes, drop newlines)."""
    return str(value).replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def _md_quote(text: str) -> str:
    """Blockquote every line of *text* so multi-line previews stay inside the quote."""
    quoted = "\n".join(f"> {line}" if line else ">" for line in text.splitlines())
    return quoted or "> "


def _group_by_category(chapters: List[dict]) -> Dict[str, List[dict]]:
    """Group chapter dicts by ``category``, preserving first-seen order."""
    grouped: Dict[str, List[dict]] = {}
    for ch in chapters:
        grouped.setdefault(ch["category"], []).append(ch)
    return grouped


def _build_comparison_markdown(result: PreviewResult) -> str:
    """Render old vs. new chapters as Markdown: summary table, previews, full new text.

    Chapters are matched by category. Several chapters sharing a category
    are summed in the table (titles joined with " / ") and listed one after
    another in the detail section.
    """
    lines = [
        f"# 回忆录重整对比 — {result.nickname} ({result.phone})",
        f"\n生成时间: {result.generated_at}\n",
        "## 总览\n",
        "| 阶段 | 旧标题 | 旧字数 | 新标题 | 新字数 | 变化 |",
        "|------|--------|--------|--------|--------|------|",
    ]

    old_groups = _group_by_category(result.old_chapters)
    new_groups = _group_by_category(result.new_chapters)
    all_stages = list(dict.fromkeys([*old_groups, *new_groups]))

    total_old = total_new = 0
    for stage in all_stages:
        old = old_groups.get(stage, [])
        new = new_groups.get(stage, [])
        old_title = " / ".join(_md_cell(ch["title"]) for ch in old) if old else "—"
        new_title = " / ".join(_md_cell(ch["title"]) for ch in new) if new else "—"
        old_len = sum(ch["content_len"] for ch in old)
        new_len = sum(ch["content_len"] for ch in new)
        total_old += old_len
        total_new += new_len
        lines.append(
            f"| {_md_cell(stage)} | {old_title} | {old_len} | {new_title} | {new_len} | {new_len - old_len:+d} |"
        )
    lines.append(f"| **合计** | | **{total_old}** | | **{total_new}** | **{total_new - total_old:+d}** |")

    # Per-stage preview comparison.
    lines.append("\n---\n")
    lines.append("## 各章节详细对比\n")
    for stage in all_stages:
        lines.append(f"### {stage}\n")
        for label, group in (
            ("**旧内容预览:**\n", old_groups.get(stage, [])),
            ("**新内容预览:**\n", new_groups.get(stage, [])),
        ):
            lines.append(label)
            if not group:
                lines.append("> (无)\n")
            for ch in group:
                lines.append(f"{_md_quote(ch['content_preview'])}\n")

    # Full text of every new chapter.
    lines.append("\n---\n")
    lines.append("## 新章节完整内容\n")
    for ch in result.new_chapters:
        lines.append(f"### {ch['title']} ({ch['category']}, {ch['content_len']} 字)\n")
        lines.append(ch["content"])
        lines.append("\n")

    return "\n".join(lines)


# ── apply command ─────────────────────────────────────────────
def _deactivate_old_chapters(db, user_id: str) -> None:
    """Set is_active=False on every currently active chapter of the user."""
    old_active = (
        db.execute(
            select(Chapter).where(
                Chapter.user_id == user_id, Chapter.is_active == True  # noqa: E712
            )
        )
        .scalars()
        .all()
    )
    for ch in old_active:
        ch.is_active = False
    logger.info(f"已将 {len(old_active)} 个旧章节标记为 inactive")


def _reset_memoir_state(db, user_id: str) -> None:
    """Replace the user's MemoirState row with a fresh ``default_state()``."""
    old_state = db.execute(
        select(MemoirState).where(MemoirState.user_id == user_id)
    ).scalar_one_or_none()
    if old_state:
        db.delete(old_state)
        # Flush the DELETE now: within one flush SQLAlchemy emits a mapper's
        # INSERTs before its DELETEs, which would clash if user_id is unique.
        db.flush()
        logger.info("已删除旧 MemoirState")

    ds = default_state()
    db.add(MemoirState(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=ds.stage_order,
        current_stage=ds.current_stage,
        covered_stages=ds.covered_stages,
        slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in ds.slots.items()},
    ))


def _insert_chapters(db, user_id: str, new_chapters: list) -> Optional[str]:
    """Insert the chapters and their ChapterSection rows; return the last chapter id.

    Chapter rows carry no body text: the text (and image slot) goes into
    chapter_sections, split by ``split_narrative_to_sections``.
    """
    last_chapter_id = None
    for ch_data in new_chapters:
        ch_id = str(uuid.uuid4())
        db.add(Chapter(
            id=ch_id,
            user_id=user_id,
            title=ch_data["title"],
            order_index=ch_data["order_index"],
            status="completed",
            category=ch_data["category"],
            cover_image=None,
            is_new=True,
            source_segments=ch_data.get("source_segment_ids", []),
        ))
        # Presumably so the chapter row exists before its sections reference it.
        db.flush()

        content = ch_data.get("content") or ""
        sections = split_narrative_to_sections(content)
        if sections:
            section_texts = [(sec.get("content") or "").strip() for sec in sections]
        else:
            # Nothing split out: store the whole text as a single section.
            section_texts = [content.strip()]
        for order_idx, section_text in enumerate(section_texts):
            db.add(ChapterSection(
                id=uuid.uuid4().hex,  # 32 hex chars, same as the dash-stripped uuid string
                chapter_id=ch_id,
                order_index=order_idx,
                content=section_text,
                image=None,
            ))

        last_chapter_id = ch_id
        logger.info(f" 新建章节: [{ch_data['category']}] {ch_data['title']} — {len(content)} 字")
    return last_chapter_id


def _flag_book_update(db, user_id: str, last_chapter_id: Optional[str]) -> None:
    """Flag the user's most recently updated Book (creating one if none) as updated."""
    # A user may own several books: take the newest instead of requiring one.
    book = (
        db.execute(
            select(Book)
            .where(Book.user_id == user_id)
            .order_by(Book.updated_at.desc())
            .limit(1)
        )
        .scalars()
        .first()
    )
    if not book:
        book = Book(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title="我的回忆录",
            total_pages=0,
            total_words=0,
            cover_image_url=None,
        )
        db.add(book)
    book.has_update = True
    if last_chapter_id:
        book.last_update_chapter_id = last_chapter_id


def _mark_segments_processed(db, user_id: str, previewed_ids: Optional[list]) -> None:
    """Mark the user's segments processed, limited to *previewed_ids* when given.

    Segments created after the preview are not in any new chapter, so they
    are left unprocessed. Without *previewed_ids* every segment is marked.
    """
    segs = (
        db.execute(
            select(Segment)
            .join(Conversation, Segment.conversation_id == Conversation.id)
            .where(Conversation.user_id == user_id)
        )
        .scalars()
        .all()
    )
    if previewed_ids is None:
        logger.warning("预览文件未记录 segment_ids, 将标记该用户全部段落为已处理")
        to_mark = segs
    else:
        wanted = {str(sid) for sid in previewed_ids}
        to_mark = [seg for seg in segs if str(seg.id) in wanted]
        skipped = len(segs) - len(to_mark)
        if skipped:
            logger.info(f"{skipped} 个段落在 preview 之后新增, 保持未处理状态")
    for seg in to_mark:
        seg.processed = True


def cmd_apply(phone: str):
    """Write the chapters saved by ``preview`` for *phone* to the remote DB.

    After an interactive "yes", in a single transaction (committed only at
    the end, so a failing step writes nothing):
      1. deactivate the user's active chapters;
      2. replace the user's MemoirState with ``default_state()``;
      3. insert the new chapters and their sections;
      4. flag the user's newest Book as updated (creating one if missing);
      5. mark the previewed segments as processed.

    Exits with status 1 when the preview file is missing or contains no
    chapters; returns without writing when the user does not confirm.
    """
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    if not os.path.exists(json_path):
        logger.error(f"未找到预览文件: {json_path}")
        logger.error("请先运行 preview 命令")
        sys.exit(1)

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    user_id = data["user_id"]
    new_chapters = data["new_chapters"]
    if not new_chapters:
        # Applying would deactivate every existing chapter and add none.
        logger.error("预览文件中没有新章节, 拒绝写入(否则会清空现有章节)")
        sys.exit(1)
    logger.info(f"将写入 {len(new_chapters)} 个新章节到用户 {data['nickname']} ({user_id})")

    try:
        answer = input("\n确认写入远程数据库? (yes/no): ").strip().lower()
    except EOFError:
        # No interactive stdin: treat as "no".
        answer = ""
    if answer != "yes":
        logger.info("已取消")
        return

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            _deactivate_old_chapters(db, user_id)
            _reset_memoir_state(db, user_id)
            last_chapter_id = _insert_chapters(db, user_id, new_chapters)
            _flag_book_update(db, user_id, last_chapter_id)
            _mark_segments_processed(db, user_id, data.get("segment_ids"))
            db.commit()
            logger.info("远程数据库写入完成!")
        finally:
            db.close()
    finally:
        tunnel.stop()


# ── CLI entry point ───────────────────────────────────────────
def _positive_int(value: str) -> int:
    """argparse ``type``: parse *value* as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"不是有效整数: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"必须 >= 1, 当前为 {number}")
    return number


def main():
    """Parse the command line and dispatch to ``preview`` or ``apply``."""
    parser = argparse.ArgumentParser(description="重新整理用户历史对话为回忆录章节(远程预览+写入)")
    sub = parser.add_subparsers(dest="command", required=True)

    p_preview = sub.add_parser("preview", help="预览:读取远程 DB,本地生成新章节,输出对比")
    p_preview.add_argument("--phone", required=True, help="用户手机号")
    p_preview.add_argument("--batch-size", type=_positive_int, default=5, help="每批段落数(默认 5)")
    p_preview.add_argument("--skip-llm-slots", action="store_true", help="跳过 LLM slot 提取")

    p_apply = sub.add_parser("apply", help="写入:将 preview 结果写入远程 DB")
    p_apply.add_argument("--phone", required=True, help="用户手机号(需与 preview 一致)")

    args = parser.parse_args()
    if args.command == "preview":
        cmd_preview(phone=args.phone, batch_size=args.batch_size, skip_llm_slots=args.skip_llm_slots)
    elif args.command == "apply":
        cmd_apply(phone=args.phone)


if __name__ == "__main__":
    main()