Files
life-echo/api/scripts/reprocess_user_memoir.py

682 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
重新整理用户历史对话为回忆录章节(远程预览 + 确认后写入)
用法:
cd api
# 第一步: 预览 (只读远程 DB, 本地生成新章节, 输出对比 Markdown)
python -m scripts.reprocess_user_memoir preview --phone 13800138000
# 第二步:确认后写入远程 DB
python -m scripts.reprocess_user_memoir apply --phone 13800138000
流程:
preview:
1. SSH 隧道连接远程 PostgreSQL
2. 读取用户现有章节 + 所有历史对话段落
3. 本地调用 LLM 生成新章节 (不写入远程 DB)
4. 输出对比 Markdown 表格 + 保存结果到 JSON 文件
apply:
1. 读取上次 preview 保存的 JSON 文件
2. SSH 隧道连接远程 PostgreSQL
3. 旧章节 is_active=False, 写入新章节 (同时重置 MemoirState, 更新 Book, 标记段落为已处理)
"""
import argparse
import json
import os
import sys
import uuid
import time
from datetime import datetime, timezone
from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
# 确保 api/ 目录在 sys.path 中
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 配置由 app.core.config.settings 统一加载
import socket
import subprocess
import signal
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker, Session
from app.core.db import Base
from app.features.conversation.models import Conversation, Segment
from app.features.memoir.models import Book, Chapter, ChapterSection, MemoirState
from app.features.user.models import User
from app.core.dependencies import get_llm_provider
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
inject_image_placeholder_template,
STAGE_TO_ORDER,
)
from app.features.memoir.memoir_images.parser import split_narrative_to_sections
from app.core.logging import get_logger, setup_logging
# Initialise project-wide logging, then obtain this module's logger.
setup_logging()
logger = get_logger(__name__)
# ── SSH / DB configuration ────────────────────────────────────
# Every value may be overridden through an environment variable, so other
# hosts or credentials can be used without editing this file. The defaults
# are the values this script has always used.
# NOTE(review): the default password still lives in source control; consider
# dropping it once every operator sets REPROCESS_PG_PASSWORD.
SSH_HOST = os.environ.get("REPROCESS_SSH_HOST", "1.15.29.57")
SSH_PORT = int(os.environ.get("REPROCESS_SSH_PORT", "22"))
SSH_USER = os.environ.get("REPROCESS_SSH_USER", "root")
SSH_KEY_PATH = os.environ.get(
    "REPROCESS_SSH_KEY_PATH",
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "..", "certs", "key.crt",
    ),
)
# PostgreSQL address as seen from the SSH host (the tunnel forwards to it).
REMOTE_PG_HOST = os.environ.get("REPROCESS_PG_HOST", "127.0.0.1")
REMOTE_PG_PORT = int(os.environ.get("REPROCESS_PG_PORT", "5432"))
PG_USER = os.environ.get("REPROCESS_PG_USER", "postgres")
PG_PASSWORD = os.environ.get("REPROCESS_PG_PASSWORD", "postgres")
PG_DATABASE = os.environ.get("REPROCESS_PG_DATABASE", "life_echo")
# preview_<phone>.json / .md files are written next to this script.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
# ── Keyword-based stage detection ─────────────────────────────
# Stages are tried in insertion order; the first one with any keyword hit wins.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}


def _detect_stage(text: str, fallback: str) -> str:
    """Return the first stage in STAGE_KEYWORDS with a keyword contained in *text*.

    Matching is a plain substring test against the lower-cased text.
    *fallback* is returned when no keyword matches.
    """
    haystack = text.lower()
    hits = (
        stage
        for stage, words in STAGE_KEYWORDS.items()
        if any(word in haystack for word in words)
    )
    return next(hits, fallback)
# ── SSH tunnel + DB session ───────────────────────────────────
class SshTunnel:
    """Local port forward to the remote PostgreSQL via an ``ssh -L`` subprocess.

    Shelling out to the ``ssh`` binary (instead of paramiko) keeps the script
    independent of paramiko versions. While running, connections to
    ``127.0.0.1:<local_port>`` reach ``REMOTE_PG_HOST:REMOTE_PG_PORT`` as seen
    from ``SSH_HOST``.

    Also usable as a context manager: ``with SshTunnel() as tunnel: ...``.
    """

    def __init__(self, local_port: int = 15432):
        self.local_port = local_port
        self._proc: Optional[subprocess.Popen] = None

    def start(self):
        """Spawn ssh and block until the local port accepts TCP connections.

        Raises:
            RuntimeError: if ssh exits early (bad key, local port in use, ...)
                or the port is not reachable within about 15 seconds. On
                timeout the ssh process is terminated before raising, so it
                is not left running in the background.
        """
        key_path = os.path.normpath(SSH_KEY_PATH)
        cmd = [
            "ssh", "-N", "-L",
            f"{self.local_port}:{REMOTE_PG_HOST}:{REMOTE_PG_PORT}",
            "-i", key_path,
            "-p", str(SSH_PORT),
            # NOTE(review): disables host-key verification (MITM risk);
            # consider "accept-new" once every operator has OpenSSH >= 7.6.
            "-o", "StrictHostKeyChecking=no",
            # Exit instead of running without the forward (e.g. port taken).
            "-o", "ExitOnForwardFailure=yes",
            # Never prompt interactively for passwords/passphrases.
            "-o", "BatchMode=yes",
            f"{SSH_USER}@{SSH_HOST}",
        ]
        logger.info(f"SSH 隧道: {SSH_USER}@{SSH_HOST}:{SSH_PORT} -> 127.0.0.1:{self.local_port}, key={key_path}")
        proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        self._proc = proc

        # Poll up to 30 x 0.5 s (~15 s) for the forwarded port to come up.
        for attempt in range(30):
            if proc.poll() is not None:
                self._proc = None
                raise RuntimeError(f"SSH 隧道进程已退出: {self._drain_stderr(proc)}")
            try:
                with socket.create_connection(("127.0.0.1", self.local_port), timeout=1):
                    pass
            except OSError:  # covers ConnectionRefusedError and socket timeouts
                time.sleep(0.5)
                continue
            logger.info(f"SSH 隧道已建立, 本地端口: {self.local_port} (耗时 {attempt * 0.5:.1f}s)")
            return

        # Timed out: kill ssh instead of leaking it. The caller never receives
        # this object when start() raises, so it could not stop it later.
        self._proc = None
        self._terminate(proc)
        raise RuntimeError(f"SSH 隧道端口 {self.local_port} 超时未就绪: {self._drain_stderr(proc)}")

    def stop(self):
        """Shut the ssh process down if one was started; safe to call repeatedly."""
        proc, self._proc = self._proc, None
        if proc is None:
            return
        was_running = proc.poll() is None
        self._terminate(proc)
        if proc.stderr:
            proc.stderr.close()
        if was_running:
            logger.info("SSH 隧道已关闭")

    @staticmethod
    def _terminate(proc: subprocess.Popen) -> None:
        """Stop *proc* with SIGTERM, escalating to kill if it has not exited after 5 s."""
        if proc.poll() is not None:
            return
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    @staticmethod
    def _drain_stderr(proc: subprocess.Popen) -> str:
        """Read and close *proc*'s stderr; only call once the process has exited."""
        if proc.stderr is None:
            return ""
        try:
            return proc.stderr.read().decode(errors="replace")
        finally:
            proc.stderr.close()

    def __enter__(self) -> "SshTunnel":
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.stop()

    @property
    def local_bind_port(self) -> int:
        """Local port that forwards to the remote PostgreSQL server."""
        return self.local_port
def open_ssh_tunnel() -> SshTunnel:
    """Start an SshTunnel on its default local port and return it already running.

    Any exception raised by ``SshTunnel.start`` propagates to the caller.
    """
    running = SshTunnel()
    running.start()
    return running
def make_session(tunnel: SshTunnel) -> Session:
    """Open a SQLAlchemy session to the remote database through *tunnel*.

    The URL is built with ``URL.create`` so that credentials containing
    characters such as ``@``, ``:`` or ``/`` are escaped correctly.
    ``NullPool`` makes ``Session.close()`` really close the DBAPI connection,
    so no pooled connection outlives the SSH tunnel it runs through.
    """
    from sqlalchemy.engine import URL
    from sqlalchemy.pool import NullPool

    url = URL.create(
        "postgresql",
        username=PG_USER,
        password=PG_PASSWORD,
        host="127.0.0.1",
        port=tunnel.local_bind_port,
        database=PG_DATABASE,
    )
    engine = create_engine(url, poolclass=NullPool)
    return sessionmaker(bind=engine)()
# ── Data structures: generated results saved by preview ───────
@dataclass
class GeneratedChapter:
    """A single chapter produced in memory by ``generate_chapters_in_memory``."""

    # Life-stage key, e.g. "childhood".
    category: str
    title: str
    # Narrative text after the image-placeholder template has been injected.
    content: str
    # Sort position taken from STAGE_TO_ORDER (999 when the stage is unknown).
    order_index: int
    # Ids of the conversation segments this chapter was written from.
    source_segment_ids: List[str] = field(default_factory=list)


@dataclass
class PreviewResult:
    """Everything ``preview`` writes to JSON and ``apply`` later reads back."""

    user_id: str
    phone: str
    nickname: str
    # ISO-8601 UTC timestamp of the preview run.
    generated_at: str
    # One dict per active chapter: category, title, content_len, content_preview.
    old_chapters: List[dict] = field(default_factory=list)
    # Same keys plus the full "content", "order_index" and "source_segment_ids".
    new_chapters: List[dict] = field(default_factory=list)
# ── Core: generate chapters locally ───────────────────────────
def _strip_json_fence(raw: str) -> str:
    """Return *raw* stripped, without a surrounding Markdown code fence (```json ... ```)."""
    text = raw.strip()
    if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
        text = text[3:-3]
        # An optional language tag may directly follow the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
    return text.strip()


def _parse_slot_response(raw: str, current_stage: str):
    """Parse the slot-extraction reply into ``(stage, slots)``.

    *stage* falls back to *current_stage* when ``detected_stage`` is missing,
    null, blank or not a string; *slots* falls back to ``{}`` when ``slots``
    is not a JSON object.

    Raises:
        ValueError: if *raw* is not valid JSON or its top level is not an object.
    """
    parsed = json.loads(_strip_json_fence(raw))
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    stage = parsed.get("detected_stage")
    stage = stage.strip() if isinstance(stage, str) and stage.strip() else current_stage
    slots = parsed.get("slots")
    if not isinstance(slots, dict):
        slots = {}
    return stage, slots


def extract_slots_with_llm(llm, text: str, current_stage: str, stage_slots: dict):
    """Ask the LLM which life stage *text* belongs to and which slots it fills.

    Args:
        llm: chat model exposing ``invoke(prompt)`` whose result has ``.content``.
        text: the segment transcript.
        current_stage: stage the conversation is currently in (prompt context
            and fallback result).
        stage_slots: slots already known for the candidate stage.

    Returns:
        ``(stage, slots)``. Best effort: on any failure (prompt building, LLM
        call, unparsable reply) a warning is logged and
        ``(current_stage, {})`` is returned instead of raising.
    """
    try:
        prompt = get_state_extraction_prompt(
            user_message=text,
            current_stage=current_stage,
            stage_slots=stage_slots,
        )
        response = llm.invoke(prompt)
        return _parse_slot_response(response.content, current_stage)
    except Exception as e:
        # Deliberately broad: one bad segment must not abort the whole run.
        logger.warning(f"LLM slot 提取失败: {e}")
        return current_stage, {}
def _group_segments_by_stage(
    segments: list,
    llm,
    state: "MemoirStateSchema",
    skip_llm_slots: bool,
) -> Dict[str, list]:
    """Bucket non-empty segments by life stage, recording extracted slots on *state*.

    Each segment is first classified by keyword (falling back to
    ``state.current_stage``). Unless *skip_llm_slots* is set, the LLM's answer
    then overrides that guess and its slots are stored on ``state.slots``.
    The chosen stage becomes ``state.current_stage`` in both modes, so a
    keyword-less segment joins the most recent stage rather than always the
    initial one.

    Returns:
        ``{stage: [(segment_id, text), ...]}`` in order of first appearance.
    """
    stage_to_segments: Dict[str, list] = {}
    for idx, (seg_id, text) in enumerate(segments, 1):
        if not text or not text.strip():
            continue
        detected_stage = _detect_stage(text, state.current_stage)
        if not skip_llm_slots:
            try:
                # NOTE(review): when the LLM call fails, extract_slots_with_llm
                # returns the *previous* stage, discarding the keyword match.
                detected_stage, extracted_slots = extract_slots_with_llm(
                    llm, text, state.current_stage, state.slots.get(detected_stage, {})
                )
                for slot_name, snippet in extracted_slots.items():
                    stage_slots = state.slots.get(detected_stage, {})
                    # Latest mention wins: the slot is replaced, not merged.
                    stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=[seg_id])
                    state.slots[detected_stage] = stage_slots
            except Exception as e:
                logger.warning(f"段落 {idx} slot 提取失败: {e}")
        state.current_stage = detected_stage
        stage_to_segments.setdefault(detected_stage, []).append((seg_id, text))
        if idx % 20 == 0:
            logger.info(f"阶段检测进度: {idx}/{len(segments)}")
    return stage_to_segments


def _generate_stage_narrative(
    stage: str,
    seg_list: list,
    slot_snippets: dict,
    llm,
    batch_size: int,
):
    """Write one stage's chapter batch by batch; return ``(title, narrative)``.

    The first batch also asks the LLM for a title (kept as "<stage> 回忆" if
    the call fails or returns nothing). Each batch's LLM narrative, written
    with the running text as context, is appended to that text. If the LLM
    fails or returns an empty narrative, the raw batch transcript is used
    instead, so no content is lost. A length check also falls back to
    appending the raw text if the result is shorter than 80% of the prior
    content.
    """
    title = f"{stage} 回忆"
    existing_content = ""
    total_batches = (len(seg_list) + batch_size - 1) // batch_size
    for i in range(0, len(seg_list), batch_size):
        batch = seg_list[i : i + batch_size]
        batch_num = i // batch_size + 1
        logger.info(f"[{stage}] 处理第 {batch_num}/{total_batches} 批 ({len(batch)} 条)")
        combined_text = "\n\n".join(text for _, text in batch)
        narrative = combined_text  # fallback when the first batch fails
        try:
            if not existing_content:
                # First batch -> generate the chapter title.
                title_prompt = get_creative_title_prompt(
                    stage=stage, emotion="neutral", slots=slot_snippets
                )
                title_response = llm.invoke(title_prompt)
                # Keep the current title if the model answers with nothing.
                title = title_response.content.strip().strip('"') or title
                logger.info(f"[{stage}] 生成标题: {title}")
            narrative_prompt = get_narrative_prompt(
                stage=stage,
                slots=slot_snippets,
                new_content=combined_text,
                existing_content=existing_content,
            )
            narrative_response = llm.invoke(narrative_prompt)
            new_narrative = narrative_response.content.strip()
            if not new_narrative:
                # An empty answer would silently drop this batch; treat as failure.
                raise ValueError("LLM 返回空叙述")
            if existing_content:
                narrative = f"{existing_content}\n\n{new_narrative}"
            else:
                narrative = new_narrative
        except Exception as e:
            logger.warning(f"[{stage}] LLM 生成失败: {e}")
            if existing_content:
                narrative = f"{existing_content}\n\n{combined_text}"
        # Safety net against a result that lost much of the earlier content.
        if existing_content and len(narrative) < len(existing_content) * 0.8:
            logger.warning(f"[{stage}] 内容长度异常, 回退追加模式")
            narrative = f"{existing_content}\n\n{combined_text}"
        existing_content = narrative
        logger.info(f"[{stage}] 批次 {batch_num} 完成, 累计长度: {len(existing_content)}")
        if i + batch_size < len(seg_list):
            time.sleep(1)  # short pause between batches (presumably LLM rate limiting)
    return title, existing_content


def generate_chapters_in_memory(
    segments: list,  # list of (id, transcript_text)
    llm,
    batch_size: int,
    skip_llm_slots: bool,
) -> List[GeneratedChapter]:
    """Build memoir chapters purely in memory; never touches the database.

    Args:
        segments: ``(segment_id, transcript_text)`` pairs, oldest first.
        llm: chat model exposing ``invoke(prompt)`` whose result has ``.content``.
        batch_size: number of segments merged into each narrative LLM call.
        skip_llm_slots: if True, stages come from keyword matching only and no
            slot-extraction calls are made.

    Returns:
        One GeneratedChapter per detected stage, in order of first appearance.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A zero step makes range() raise; a negative one silently yields no
        # batches and would produce chapters with empty content.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    state = default_state()
    stage_to_segments = _group_segments_by_stage(segments, llm, state, skip_llm_slots)
    for stage, segs in stage_to_segments.items():
        logger.info(f"阶段 [{stage}]: {len(segs)} 条段落")

    results: List[GeneratedChapter] = []
    for stage, seg_list in stage_to_segments.items():
        slot_snippets = {
            key: value.snippet
            for key, value in (state.slots.get(stage, {}) or {}).items()
            if value.snippet
        }
        title, narrative = _generate_stage_narrative(
            stage, seg_list, slot_snippets, llm, batch_size
        )
        results.append(GeneratedChapter(
            category=stage,
            title=title,
            # Before saving, splice the fixed image-placeholder template into
            # the narrative (per the original note, positions are regex-matched).
            content=inject_image_placeholder_template(narrative),
            order_index=STAGE_TO_ORDER.get(stage, 999),
            source_segment_ids=[sid for sid, _ in seg_list],
        ))
    return results
# ── preview command ───────────────────────────────────────────
_PREVIEW_CHARS = 200


def _preview_text(content: str, limit: int = _PREVIEW_CHARS) -> str:
    """Return the first *limit* characters of *content*, marking truncation with "…"."""
    return (content[:limit] + "…") if len(content) > limit else content


def _chapter_text(chapter) -> str:
    """Join a chapter's non-blank section bodies, in section order, with blank lines."""
    sections = getattr(chapter, "sections", None) or []
    bodies = (
        (s.content or "").strip()
        for s in sorted(sections, key=lambda s: s.order_index)
    )
    return "\n\n".join(body for body in bodies if body)


def cmd_preview(phone: str, batch_size: int, skip_llm_slots: bool):
    """Regenerate a user's chapters locally and save an old-vs-new comparison.

    The remote database is only read, through the SSH tunnel. The tunnel is
    closed before the (slow) LLM generation starts. Writes
    ``preview_<phone>.json`` (consumed by ``apply``) and ``preview_<phone>.md``
    into OUTPUT_DIR and prints the Markdown. Exits with status 1 if the LLM
    is not configured or no user has this phone number.
    """
    from sqlalchemy.orm import joinedload

    llm = getattr(get_llm_provider(), "langchain_llm", None)
    if not llm:
        logger.error("LLM 未配置,请检查 .env 中的 DEEPSEEK_API_KEY")
        sys.exit(1)
    logger.info("LLM 就绪")

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            user = db.execute(select(User).where(User.phone == phone)).scalar_one_or_none()
            if not user:
                logger.error(f"未找到手机号 {phone} 的用户")
                sys.exit(1)
            user_id = user.id
            nickname = user.nickname
            logger.info(f"用户: {nickname} (id={user_id})")

            # Currently active chapters; their body text lives in chapter sections.
            old_chapters = (
                db.execute(
                    select(Chapter)
                    .where(Chapter.user_id == user_id, Chapter.is_active == True)  # noqa: E712
                    .options(joinedload(Chapter.sections))
                    .order_by(Chapter.order_index)
                )
                .unique()
                .scalars()
                .all()
            )
            old_chapter_data = []
            for ch in old_chapters:
                content = _chapter_text(ch)
                old_chapter_data.append({
                    "category": ch.category,
                    "title": ch.title,
                    "content_len": len(content),
                    "content_preview": _preview_text(content),
                })
            logger.info(f"现有章节: {len(old_chapters)}")

            # All of the user's segments, oldest first. The id tie-break keeps
            # batch composition reproducible when created_at values collide.
            segments_raw = (
                db.execute(
                    select(Segment.id, Segment.transcript_text)
                    .join(Conversation, Segment.conversation_id == Conversation.id)
                    .where(Conversation.user_id == user_id)
                    .order_by(Segment.created_at.asc(), Segment.id.asc())
                )
                .all()
            )
            logger.info(f"历史段落: {len(segments_raw)}")
            if not segments_raw:
                logger.warning("没有对话段落,无需处理")
                return
        finally:
            db.close()
            # Release pooled connections before the tunnel they use goes away.
            db.get_bind().dispose()
    finally:
        tunnel.stop()

    # Generate the new chapters locally; no database access from here on.
    seg_tuples = [(row[0], row[1]) for row in segments_raw]
    new_chapters = generate_chapters_in_memory(seg_tuples, llm, batch_size, skip_llm_slots)

    new_chapter_data = [
        {
            "category": ch.category,
            "title": ch.title,
            "content_len": len(ch.content),
            "content_preview": _preview_text(ch.content),
            "content": ch.content,
            "order_index": ch.order_index,
            "source_segment_ids": ch.source_segment_ids,
        }
        for ch in new_chapters
    ]
    result = PreviewResult(
        user_id=user_id,
        phone=phone,
        nickname=nickname,
        generated_at=datetime.now(timezone.utc).isoformat(),
        old_chapters=old_chapter_data,
        new_chapters=new_chapter_data,
    )

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(asdict(result), f, ensure_ascii=False, indent=2)
    logger.info(f"预览结果已保存: {json_path}")

    md_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.md")
    markdown = _build_comparison_markdown(result)
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(markdown)
    logger.info(f"对比 Markdown 已保存: {md_path}")
    # Also echo the comparison to the terminal.
    print("\n" + markdown)
def _build_comparison_markdown(result: "PreviewResult") -> str:
    """Render the old-vs-new chapter comparison as a Markdown document.

    The document has three parts: an overview table (one row per stage plus
    a totals row), a per-stage preview comparison, and the full text of
    every new chapter. Stages appear in order of first occurrence, old
    chapters first. Several chapters sharing a category are combined into
    one row (titles joined with " / ", lengths summed) instead of all but
    the last one being dropped.
    """

    def cell(value) -> str:
        # "|" and line breaks would split or end the table row.
        return str(value).replace("|", "\\|").replace("\r", " ").replace("\n", " ")

    def quote(text: str) -> str:
        # Prefix every line so multi-paragraph previews stay inside the blockquote.
        return "\n".join(f"> {line}" if line else ">" for line in text.splitlines()) or ">"

    def by_stage(chapters: List[dict]) -> Dict[str, List[dict]]:
        grouped: Dict[str, List[dict]] = {}
        for ch in chapters:
            grouped.setdefault(ch["category"], []).append(ch)
        return grouped

    old_by_stage = by_stage(result.old_chapters)
    new_by_stage = by_stage(result.new_chapters)
    all_stages = list(dict.fromkeys([*old_by_stage, *new_by_stage]))

    # Overview table
    lines = [
        f"# 回忆录重整对比 — {result.nickname} ({result.phone})",
        f"\n生成时间: {result.generated_at}\n",
        "## 总览\n",
        "| 阶段 | 旧标题 | 旧字数 | 新标题 | 新字数 | 变化 |",
        "|------|--------|--------|--------|--------|------|",
    ]
    total_old = 0
    total_new = 0
    for stage in all_stages:
        olds = old_by_stage.get(stage, [])
        news = new_by_stage.get(stage, [])
        old_title = " / ".join(cell(ch["title"]) for ch in olds)
        new_title = " / ".join(cell(ch["title"]) for ch in news)
        old_len = sum(ch["content_len"] for ch in olds)
        new_len = sum(ch["content_len"] for ch in news)
        total_old += old_len
        total_new += new_len
        lines.append(
            f"| {cell(stage)} | {old_title} | {old_len} | {new_title} | {new_len} | {new_len - old_len:+d} |"
        )
    lines.append(
        f"| **合计** | | **{total_old}** | | **{total_new}** | **{total_new - total_old:+d}** |"
    )

    # Per-stage preview comparison
    lines.append("\n---\n")
    lines.append("## 各章节详细对比\n")
    for stage in all_stages:
        lines.append(f"### {stage}\n")
        for label, chapters in (
            ("**旧内容预览:**\n", old_by_stage.get(stage)),
            ("**新内容预览:**\n", new_by_stage.get(stage)),
        ):
            lines.append(label)
            if chapters:
                lines.extend(quote(ch["content_preview"]) + "\n" for ch in chapters)
            else:
                lines.append("> (无)\n")

    # Full text of every new chapter
    lines.append("\n---\n")
    lines.append("## 新章节完整内容\n")
    for ch in result.new_chapters:
        lines.append(f"### {ch['title']} ({ch['category']}, {ch['content_len']} 字)\n")
        lines.append(ch["content"])
        lines.append("\n")
    return "\n".join(lines)
# ── apply command ─────────────────────────────────────────────
def _deactivate_active_chapters(db: "Session", user_id) -> int:
    """Set is_active=False on every active chapter of the user; return how many."""
    old_active = (
        db.execute(
            select(Chapter).where(
                Chapter.user_id == user_id, Chapter.is_active == True  # noqa: E712
            )
        )
        .scalars()
        .all()
    )
    for ch in old_active:
        ch.is_active = False
    return len(old_active)


def _reset_memoir_state(db: "Session", user_id) -> None:
    """Replace the user's MemoirState (if any) with a fresh default_state()."""
    old_state = db.execute(
        select(MemoirState).where(MemoirState.user_id == user_id)
    ).scalar_one_or_none()
    if old_state:
        db.delete(old_state)
        # Within one flush SQLAlchemy emits INSERTs before DELETEs; flushing
        # here guarantees the old row is gone before the new one is inserted
        # (matters if user_id is unique — unverified from this file).
        db.flush()
        logger.info("已删除旧 MemoirState")
    ds = default_state()
    db.add(MemoirState(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=ds.stage_order,
        current_stage=ds.current_stage,
        covered_stages=ds.covered_stages,
        slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in ds.slots.items()},
    ))


def _insert_chapters(db: "Session", user_id, new_chapters: List[dict]) -> Optional[str]:
    """Insert the previewed chapters and their sections; return the last chapter's id.

    Chapters carry no body text themselves: each narrative is split into
    ChapterSection rows (one section holding the whole text if it cannot be
    split).
    """
    last_chapter_id = None
    for ch_data in new_chapters:
        ch_id = str(uuid.uuid4())
        db.add(Chapter(
            id=ch_id,
            user_id=user_id,
            title=ch_data["title"],
            order_index=ch_data["order_index"],
            status="completed",
            category=ch_data["category"],
            cover_image=None,
            is_new=True,
            source_segments=ch_data.get("source_segment_ids", []),
        ))
        db.flush()  # write the chapter row before sections reference it
        content = ch_data.get("content") or ""
        sections = split_narrative_to_sections(content) or [{"content": content}]
        for order_idx, section in enumerate(sections):
            db.add(ChapterSection(
                id=uuid.uuid4().hex,  # 32 hex chars, same as the dash-less uuid string
                chapter_id=ch_id,
                order_index=order_idx,
                content=(section.get("content") or "").strip(),
                image=None,
            ))
        last_chapter_id = ch_id
        logger.info(f" 新建章节: [{ch_data['category']}] {ch_data['title']} ({ch_data['content_len']} 字)")
    return last_chapter_id


def _touch_book(db: "Session", user_id, last_chapter_id: Optional[str]) -> None:
    """Flag the user's most recently updated Book (created if missing) as updated."""
    # limit(1): a user may own several books; scalar_one_or_none() would raise.
    book = (
        db.execute(
            select(Book)
            .where(Book.user_id == user_id)
            .order_by(Book.updated_at.desc())
            .limit(1)
        )
        .scalars()
        .first()
    )
    if not book:
        book = Book(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title="我的回忆录",
            total_pages=0,
            total_words=0,
            cover_image_url=None,
        )
        db.add(book)
    book.has_update = True
    if last_chapter_id:
        book.last_update_chapter_id = last_chapter_id


def _mark_segments_processed(db: "Session", user_id, source_ids: set) -> int:
    """Mark the user's segments processed if the preview covered them.

    A segment counts as covered when its id (compared as str) is a source of
    a new chapter, or its transcript is blank (preview skips blank segments).
    Other segments — e.g. created after the preview ran — are left alone so
    they are not marked done without appearing in any chapter.

    Returns:
        The number of segments left untouched.
    """
    segs = (
        db.execute(
            select(Segment)
            .join(Conversation, Segment.conversation_id == Conversation.id)
            .where(Conversation.user_id == user_id)
        )
        .scalars()
        .all()
    )
    skipped = 0
    for seg in segs:
        if str(seg.id) in source_ids or not (seg.transcript_text or "").strip():
            seg.processed = True
        else:
            skipped += 1
    return skipped


def cmd_apply(phone: str):
    """Write the chapters saved in ``preview_<phone>.json`` to the remote DB.

    After an interactive "yes" confirmation, and before a single commit:
    deactivate the user's active chapters, reset MemoirState, insert the new
    chapters and sections, flag the Book as updated, and mark covered
    segments processed. If any step raises, the session is rolled back and
    nothing is written. Exits with status 1 if the preview file is missing.
    """
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    if not os.path.exists(json_path):
        logger.error(f"未找到预览文件: {json_path}")
        logger.error("请先运行 preview 命令")
        sys.exit(1)
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    user_id = data["user_id"]
    new_chapters = data["new_chapters"]
    source_ids = {
        str(sid)
        for ch in new_chapters
        for sid in (ch.get("source_segment_ids") or [])
    }
    logger.info(f"将写入 {len(new_chapters)} 个新章节到用户 {data['nickname']} ({user_id})")

    answer = input("\n确认写入远程数据库? (yes/no): ").strip().lower()
    if answer != "yes":
        logger.info("已取消")
        return

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            deactivated = _deactivate_active_chapters(db, user_id)
            logger.info(f"已将 {deactivated} 个旧章节标记为 inactive")
            _reset_memoir_state(db, user_id)
            last_chapter_id = _insert_chapters(db, user_id, new_chapters)
            _touch_book(db, user_id, last_chapter_id)
            skipped = _mark_segments_processed(db, user_id, source_ids)
            if skipped:
                logger.warning(f"{skipped} 个段落不在预览结果中, 保持未处理状态")
            db.commit()
            logger.info("远程数据库写入完成!")
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()
            # Release pooled connections before the tunnel they use goes away.
            db.get_bind().dispose()
    finally:
        tunnel.stop()
# ── CLI entry point ───────────────────────────────────────────
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def main():
    """Parse the command line and dispatch to the ``preview`` or ``apply`` sub-command."""
    parser = argparse.ArgumentParser(description="重新整理用户历史对话为回忆录章节(远程预览+写入)")
    sub = parser.add_subparsers(dest="command", required=True)

    p_preview = sub.add_parser("preview", help="预览: 读取远程 DB, 本地生成新章节, 输出对比")
    p_preview.add_argument("--phone", required=True, help="用户手机号")
    # A non-positive batch size would crash (0) or silently produce empty chapters (<0).
    p_preview.add_argument("--batch-size", type=_positive_int, default=5, help="每批段落数 (默认 5)")
    p_preview.add_argument("--skip-llm-slots", action="store_true", help="跳过 LLM slot 提取")

    p_apply = sub.add_parser("apply", help="写入:将 preview 结果写入远程 DB")
    p_apply.add_argument("--phone", required=True, help="用户手机号(需与 preview 一致)")

    args = parser.parse_args()
    if args.command == "preview":
        cmd_preview(phone=args.phone, batch_size=args.batch_size, skip_llm_slots=args.skip_llm_slots)
    elif args.command == "apply":
        cmd_apply(phone=args.phone)


if __name__ == "__main__":
    main()