687 lines
25 KiB
Python
687 lines
25 KiB
Python
"""
|
||
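Re-organize a user's conversation history into memoir chapters
(remote preview first, written to the database only after confirmation).

Usage:
    cd api

    # Step 1: preview (reads the remote DB only, generates the new chapters
    # locally, and outputs a comparison in Markdown)
    python -m scripts.reprocess_user_memoir preview --phone 13800138000

    # Step 2: after confirming, write to the remote DB
    python -m scripts.reprocess_user_memoir apply --phone 13800138000

Workflow:
    preview:
        1. Connect to the remote PostgreSQL through an SSH tunnel
        2. Read the user's existing chapters and all historical conversation segments
        3. Call the LLM locally to generate the new chapters (nothing is written to the remote DB)
        4. Output a Markdown comparison table and save the result to a JSON file

    apply:
        1. Read the JSON file saved by the last preview
        2. Connect to the remote PostgreSQL through an SSH tunnel
        3. Set is_active=False on the old chapters and write the new chapters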
|
||
"""
|
||
import argparse
|
||
import json
|
||
import os
|
||
import sys
|
||
import uuid
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from typing import Dict, List, Optional
|
||
from dataclasses import dataclass, field, asdict
|
||
|
||
# 确保 api/ 目录在 sys.path 中
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
# 配置由 app.core.config.settings 统一加载
|
||
|
||
import socket
|
||
import subprocess
|
||
import signal
|
||
|
||
from sqlalchemy import create_engine, select
|
||
from sqlalchemy.orm import sessionmaker, Session
|
||
|
||
from app.core.db import Base
|
||
from app.features.conversation.models import Conversation, Segment
|
||
from app.features.memoir.models import Book, Chapter, ChapterSection, MemoirState
|
||
from app.features.user.models import User
|
||
from app.core.dependencies import get_llm_provider
|
||
from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
|
||
from app.agents.memoir.prompts import (
|
||
get_creative_title_prompt,
|
||
get_narrative_prompt,
|
||
get_state_extraction_prompt,
|
||
inject_image_placeholder_template,
|
||
STAGE_TO_ORDER,
|
||
)
|
||
from app.features.memoir.memoir_images.json_payload import extract_json_payload
|
||
from app.features.memoir.memoir_images.parser import split_narrative_to_sections
|
||
from app.core.logging import get_logger, setup_logging
|
||
|
||
setup_logging()
|
||
logger = get_logger(__name__)
|
||
|
||
# ── SSH / DB 配置 ──────────────────────────────────────────────
|
||
|
||
# Connection settings for the SSH jump host and the PostgreSQL instance behind it.
# Each value can be overridden with the environment variable named in its
# ``os.environ.get`` call; the defaults are the values that used to be hard-coded,
# so running without any of these variables set behaves exactly as before.
# NOTE(review): a root SSH login and a database password are still present as
# in-source defaults — consider supplying them only through the environment.
SSH_HOST = os.environ.get("MEMOIR_SSH_HOST", "1.15.29.57")
SSH_PORT = int(os.environ.get("MEMOIR_SSH_PORT", "22"))
SSH_USER = os.environ.get("MEMOIR_SSH_USER", "root")
# Default key location: <parent of this file's parent directory>/../certs/key.crt
SSH_KEY_PATH = os.environ.get(
    "MEMOIR_SSH_KEY_PATH",
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "..", "certs", "key.crt",
    ),
)

# PostgreSQL address as seen from the SSH host (the far end of the tunnel).
REMOTE_PG_HOST = os.environ.get("MEMOIR_REMOTE_PG_HOST", "127.0.0.1")
REMOTE_PG_PORT = int(os.environ.get("MEMOIR_REMOTE_PG_PORT", "5432"))
PG_USER = os.environ.get("MEMOIR_PG_USER", "postgres")
PG_PASSWORD = os.environ.get("MEMOIR_PG_PASSWORD", "postgres")
PG_DATABASE = os.environ.get("MEMOIR_PG_DATABASE", "life_echo")

# Preview artifacts (JSON + Markdown) are written next to this script.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
|
||
|
||
# ── 关键字阶段检测 ────────────────────────────────────────────
|
||
|
||
# Keywords that mark a piece of text as belonging to a life stage.
# Insertion order matters: the first stage with a matching keyword wins.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}


def _detect_stage(text: str, fallback: str) -> str:
    """Return the first stage whose keyword occurs in ``text``, else ``fallback``.

    Matching is a plain substring test against the lower-cased text, walking
    ``STAGE_KEYWORDS`` in its declared order.
    """
    lowered = text.lower()
    matches = (
        stage_name
        for stage_name, words in STAGE_KEYWORDS.items()
        if any(word in lowered for word in words)
    )
    return next(matches, fallback)
|
||
|
||
|
||
# ── SSH 隧道 + DB 会话 ────────────────────────────────────────
|
||
|
||
|
||
class SshTunnel:
    """Local port-forward to the remote PostgreSQL, run as an ``ssh -L`` child process.

    A plain ``ssh`` subprocess is used so the script works regardless of the
    installed paramiko version.
    """

    def __init__(self, local_port: int = 15432):
        """Remember the local port to bind; no process is started yet."""
        self.local_port = local_port
        self._proc: Optional[subprocess.Popen] = None

    def start(self):
        """Spawn ``ssh -N -L`` and block until the forwarded port accepts connections.

        Raises:
            RuntimeError: if the ssh process exits before the port is ready, or
                the port is still not connectable after 30 probes. In the
                timeout case the child is terminated before raising, so a
                failed start never leaves a stray ssh process behind.
        """
        key_path = os.path.normpath(SSH_KEY_PATH)
        cmd = [
            "ssh", "-N", "-L",
            f"{self.local_port}:{REMOTE_PG_HOST}:{REMOTE_PG_PORT}",
            "-i", key_path,
            "-p", str(SSH_PORT),
            # NOTE(review): disabling host-key checking accepts any server key
            # (no protection against man-in-the-middle) — confirm this is acceptable.
            "-o", "StrictHostKeyChecking=no",
            "-o", "ExitOnForwardFailure=yes",
            "-o", "BatchMode=yes",
            f"{SSH_USER}@{SSH_HOST}",
        ]
        logger.info(f"SSH 隧道: {SSH_USER}@{SSH_HOST}:{SSH_PORT} -> 127.0.0.1:{self.local_port}, key={key_path}")
        self._proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)

        # Probe the local port until it is connectable: 30 attempts with a
        # 0.5 s pause after each refused connection.
        for attempt in range(30):
            if self._proc.poll() is not None:
                err = self._read_stderr()
                raise RuntimeError(f"SSH 隧道进程已退出: {err}")
            try:
                sock = socket.create_connection(("127.0.0.1", self.local_port), timeout=1)
            except OSError:  # includes ConnectionRefusedError and socket timeouts
                time.sleep(0.5)
            else:
                sock.close()
                logger.info(f"SSH 隧道已建立, 本地端口: {self.local_port} (耗时 {attempt * 0.5:.1f}s)")
                return

        # Timed out: stop the child first so it is not leaked, then report
        # whatever it wrote to stderr.
        self._terminate()
        err = self._read_stderr()
        raise RuntimeError(f"SSH 隧道端口 {self.local_port} 超时未就绪: {err}")

    def stop(self):
        """Shut the tunnel down; safe to call repeatedly or before ``start``."""
        proc = self._proc
        if proc is None:
            return
        was_running = proc.poll() is None
        self._terminate()
        if proc.stderr is not None and not proc.stderr.closed:
            proc.stderr.close()
        self._proc = None
        if was_running:
            logger.info("SSH 隧道已关闭")

    def _terminate(self) -> None:
        """Stop the child if it is still running: terminate, then kill after 5 s."""
        proc = self._proc
        if proc is None or proc.poll() is not None:
            return
        proc.terminate()  # sends SIGTERM on POSIX, same as the previous send_signal call
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    def _read_stderr(self) -> str:
        """Return the exited child's stderr output and close the pipe.

        Returns an empty string while the child is still running, because
        reading the pipe then would block.
        """
        proc = self._proc
        if proc is None or proc.stderr is None or proc.stderr.closed:
            return ""
        if proc.poll() is None:
            return ""
        try:
            return proc.stderr.read().decode(errors="replace")
        finally:
            proc.stderr.close()

    @property
    def local_bind_port(self) -> int:
        """Local port the tunnel listens on (same value as ``local_port``)."""
        return self.local_port
|
||
|
||
|
||
def open_ssh_tunnel() -> SshTunnel:
    """Create an ``SshTunnel`` with its default local port, start it, and return it."""
    started_tunnel = SshTunnel()
    started_tunnel.start()
    return started_tunnel
|
||
|
||
|
||
def make_session(tunnel: SshTunnel) -> Session:
    """Open a SQLAlchemy session on the database reachable through ``tunnel``.

    A new engine is created on every call, pointed at the tunnel's local port
    on 127.0.0.1.
    """
    local_port = tunnel.local_bind_port
    dsn = f"postgresql://{PG_USER}:{PG_PASSWORD}@127.0.0.1:{local_port}/{PG_DATABASE}"
    engine = create_engine(dsn, pool_size=2, max_overflow=2)
    session_factory = sessionmaker(bind=engine)
    return session_factory()
|
||
|
||
|
||
# ── 数据结构:保存生成结果 ────────────────────────────────────
|
||
|
||
|
||
@dataclass
class GeneratedChapter:
    """A chapter produced in memory, before anything is written to the database."""

    # Stage key the chapter belongs to.
    category: str
    # Chapter title.
    title: str
    # Full chapter text.
    content: str
    # Sort position of the chapter.
    order_index: int
    # Ids of the segments the chapter was generated from; a fresh list per instance.
    source_segment_ids: List[str] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class PreviewResult:
    """Everything a preview run produces for one user."""

    user_id: str
    phone: str
    nickname: str
    # Timestamp string recording when the preview was generated.
    generated_at: str
    # One dict per existing chapter: {category, title, content_len, content_preview}
    old_chapters: List[dict] = field(default_factory=list)
    # One dict per regenerated chapter: same shape as above plus the full content
    new_chapters: List[dict] = field(default_factory=list)
|
||
|
||
|
||
# ── 核心:本地生成章节 ────────────────────────────────────────
|
||
|
||
|
||
def extract_slots_with_llm(llm, text: str, current_stage: str, stage_slots: dict):
    """Ask the LLM which stage ``text`` belongs to and which slots it fills.

    Args:
        llm: LLM object; it must provide ``bind(...)`` returning an object with ``invoke``.
        text: The user's message to analyse.
        current_stage: Stage used when the model gives no usable stage.
        stage_slots: Slots passed to the prompt builder as ``stage_slots``.

    Returns:
        A ``(detected_stage, slots)`` tuple. ``detected_stage`` is always a
        non-empty string and ``slots`` is always a dict. This is best-effort:
        on any failure (LLM error, invalid JSON, unexpected payload shape) a
        warning is logged and ``(current_stage, {})`` is returned.
    """
    try:
        prompt = get_state_extraction_prompt(
            user_message=text,
            current_stage=current_stage,
            stage_slots=stage_slots,
        )
        json_llm = llm.bind(
            model_kwargs={"response_format": {"type": "json_object"}},
            max_tokens=1024,
        )
        response = json_llm.invoke(prompt)
        parsed = json.loads(extract_json_payload(response.content.strip()))
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    except Exception as e:
        logger.warning(f"LLM slot 提取失败: {e}")
        return current_stage, {}

    # The model may emit null / empty / non-string values; never let those
    # reach the caller, which uses the stage as a dict key and chapter category.
    detected_stage = parsed.get("detected_stage")
    if isinstance(detected_stage, str) and detected_stage.strip():
        detected_stage = detected_stage.strip()
    else:
        detected_stage = current_stage

    slots = parsed.get("slots")
    if not isinstance(slots, dict):
        slots = {}
    return detected_stage, slots
|
||
|
||
|
||
def generate_chapters_in_memory(
    segments: list,  # list of (id, transcript_text)
    llm,
    batch_size: int,
    skip_llm_slots: bool,
) -> "List[GeneratedChapter]":
    """Generate chapters purely in memory; nothing is written to any database.

    Args:
        segments: ``(segment_id, transcript_text)`` pairs in processing order.
        llm: LLM object exposing ``invoke`` (and ``bind`` when slot extraction is on).
        batch_size: Number of segments fed to the LLM per narrative call; must be >= 1.
        skip_llm_slots: When true, stages come from keyword detection only and
            no slot extraction is performed.

    Returns:
        One ``GeneratedChapter`` per detected stage, in first-seen stage order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1. Checked up front: a
            negative value would otherwise produce empty chapters silently, and
            zero would only fail after the slot-extraction pass had already run.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    state = default_state()

    # 1. Stage detection & slot extraction (state lives in memory only)
    stage_to_segments = _assign_segments_to_stages(segments, llm, state, skip_llm_slots)
    for stage, segs in stage_to_segments.items():
        logger.info(f"阶段 [{stage}]: {len(segs)} 条段落")

    # 2. Generate each stage's chapter batch by batch
    results: List[GeneratedChapter] = []
    for stage, seg_list in stage_to_segments.items():
        slot_snippets = {
            key: value.snippet
            for key, value in (state.slots.get(stage, {}) or {}).items()
            if value.snippet
        }
        results.append(_compose_stage_chapter(stage, seg_list, slot_snippets, llm, batch_size))
    return results


def _assign_segments_to_stages(segments: list, llm, state, skip_llm_slots: bool) -> Dict[str, list]:
    """Group non-blank segments by stage, recording extracted slots on ``state``."""
    stage_to_segments: Dict[str, list] = {}

    for idx, (seg_id, text) in enumerate(segments, 1):
        if not text or not text.strip():
            continue

        # Keyword heuristic first; the LLM result (when enabled) overrides it.
        detected_stage = _detect_stage(text, state.current_stage)

        if not skip_llm_slots:
            try:
                detected_stage, extracted_slots = extract_slots_with_llm(
                    llm, text, state.current_stage, state.slots.get(detected_stage, {})
                )
                # Update the in-memory state slots
                for slot_name, snippet in extracted_slots.items():
                    stage_slots = state.slots.get(detected_stage, {})
                    stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=[seg_id])
                    state.slots[detected_stage] = stage_slots
                state.current_stage = detected_stage
            except Exception as e:
                logger.warning(f"段落 {idx} slot 提取失败: {e}")

        stage_to_segments.setdefault(detected_stage, []).append((seg_id, text))

        if idx % 20 == 0:
            logger.info(f"阶段检测进度: {idx}/{len(segments)}")

    return stage_to_segments


def _compose_stage_chapter(stage: str, seg_list: list, slot_snippets: dict, llm, batch_size: int) -> "GeneratedChapter":
    """Build one stage's chapter, feeding its segments to the LLM in batches."""
    title = f"{stage} 回忆"
    existing_content = ""
    all_source_ids: List[str] = []
    total_batches = (len(seg_list) + batch_size - 1) // batch_size

    for batch_num, start in enumerate(range(0, len(seg_list), batch_size), 1):
        batch = seg_list[start : start + batch_size]
        logger.info(f"[{stage}] 处理第 {batch_num}/{total_batches} 批 ({len(batch)} 条)")

        combined_text = "\n\n".join(text for _, text in batch)
        all_source_ids.extend(sid for sid, _ in batch)
        narrative = combined_text  # fallback: raw transcript text

        try:
            if not existing_content:
                # First batch -> generate the title
                title_prompt = get_creative_title_prompt(
                    stage=stage, emotion="neutral", slots=slot_snippets
                )
                title_response = llm.invoke(title_prompt)
                generated_title = title_response.content.strip().strip('"')
                if generated_title:  # keep the default title when the LLM returns nothing
                    title = generated_title
                logger.info(f"[{stage}] 生成标题: {title}")

            narrative_prompt = get_narrative_prompt(
                stage=stage,
                slots=slot_snippets,
                new_content=combined_text,
                existing_content=existing_content,
            )
            narrative_response = llm.invoke(narrative_prompt)
            new_narrative = narrative_response.content.strip()
            if not new_narrative:
                # An empty reply would silently drop this batch; treat it as a
                # failure so the raw-text fallback below is used instead.
                raise ValueError("LLM returned an empty narrative")

            if existing_content:
                narrative = f"{existing_content}\n\n{new_narrative}"
            else:
                narrative = new_narrative
        except Exception as e:
            logger.warning(f"[{stage}] LLM 生成失败: {e}")
            if existing_content:
                narrative = f"{existing_content}\n\n{combined_text}"

        # Safety check: never let the accumulated text shrink noticeably.
        if existing_content and len(narrative) < len(existing_content) * 0.8:
            logger.warning(f"[{stage}] 内容长度异常, 回退追加模式")
            narrative = f"{existing_content}\n\n{combined_text}"

        existing_content = narrative

        logger.info(f"[{stage}] 批次 {batch_num} 完成, 累计长度: {len(existing_content)} 字")
        if start + batch_size < len(seg_list):
            time.sleep(1)  # pause between consecutive batches

    # Before saving: run the finished text through the image-placeholder template helper
    content_to_save = inject_image_placeholder_template(existing_content)
    return GeneratedChapter(
        category=stage,
        title=title,
        content=content_to_save,
        order_index=STAGE_TO_ORDER.get(stage, 999),
        source_segment_ids=all_source_ids,
    )
|
||
|
||
|
||
# ── preview 命令 ──────────────────────────────────────────────
|
||
|
||
|
||
def cmd_preview(phone: str, batch_size: int, skip_llm_slots: bool):
    """Run the ``preview`` command for the user identified by ``phone``.

    Reads the user's active chapters and all conversation segments through the
    SSH tunnel, closes the connection, regenerates the chapters locally, then
    writes ``preview_<phone>.json`` and ``preview_<phone>.md`` into
    ``OUTPUT_DIR`` and prints the Markdown to stdout. Exits with status 1 when
    the LLM is not configured or the user is not found; returns without
    writing anything when the user has no segments.
    """
    # LLM
    provider = get_llm_provider()
    llm = getattr(provider, "langchain_llm", None)
    if not llm:
        logger.error("LLM 未配置,请检查 .env 中的 DEEPSEEK_API_KEY")
        sys.exit(1)
    logger.info("LLM 就绪")

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            # Look the user up
            user_query = select(User).where(User.phone == phone)
            user = db.execute(user_query).scalar_one_or_none()
            if not user:
                logger.error(f"未找到手机号 {phone} 的用户")
                sys.exit(1)
            user_id, nickname = user.id, user.nickname
            logger.info(f"用户: {nickname} (id={user_id})")

            # Existing active chapters, sections loaded eagerly (the chapter
            # body is assembled from its sections)
            from sqlalchemy.orm import joinedload

            chapter_query = (
                select(Chapter)
                .where(Chapter.user_id == user_id, Chapter.is_active == True)
                .options(joinedload(Chapter.sections))
                .order_by(Chapter.order_index)
            )
            old_chapters = db.execute(chapter_query).unique().scalars().all()

            old_chapter_data = []
            for chapter in old_chapters:
                body = ""
                if getattr(chapter, "sections", None):
                    ordered_sections = sorted(chapter.sections, key=lambda sec: sec.order_index)
                    pieces = ((sec.content or "").strip() for sec in ordered_sections)
                    body = "\n\n".join(piece for piece in pieces if piece)
                body_len = len(body)
                if body_len > 200:
                    body_preview = body[:200] + "…"
                else:
                    body_preview = body
                old_chapter_data.append({
                    "category": chapter.category,
                    "title": chapter.title,
                    "content_len": body_len,
                    "content_preview": body_preview,
                })
            logger.info(f"现有章节: {len(old_chapters)} 个")

            # All of the user's segments, oldest first
            segment_query = (
                select(Segment.id, Segment.transcript_text)
                .join(Conversation, Segment.conversation_id == Conversation.id)
                .where(Conversation.user_id == user_id)
                .order_by(Segment.created_at.asc())
            )
            segments_raw = db.execute(segment_query).all()
            logger.info(f"历史段落: {len(segments_raw)} 条")
            if not segments_raw:
                logger.warning("没有对话段落,无需处理")
                return
        finally:
            db.close()
    finally:
        tunnel.stop()

    # Generate the new chapters locally (no DB needed from here on)
    seg_tuples = [(row[0], row[1]) for row in segments_raw]
    new_chapters = generate_chapters_in_memory(seg_tuples, llm, batch_size, skip_llm_slots)

    # Build the comparison payload
    new_chapter_data = [
        {
            "category": generated.category,
            "title": generated.title,
            "content_len": len(generated.content),
            "content_preview": (
                generated.content[:200] + "…"
                if len(generated.content) > 200
                else generated.content
            ),
            "content": generated.content,
            "order_index": generated.order_index,
            "source_segment_ids": generated.source_segment_ids,
        }
        for generated in new_chapters
    ]

    result = PreviewResult(
        user_id=user_id,
        phone=phone,
        nickname=nickname,
        generated_at=datetime.now(timezone.utc).isoformat(),
        old_chapters=old_chapter_data,
        new_chapters=new_chapter_data,
    )

    # Save the JSON
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(asdict(result), json_file, ensure_ascii=False, indent=2)
    logger.info(f"预览结果已保存: {json_path}")

    # Save the Markdown
    md_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.md")
    markdown_text = _build_comparison_markdown(result)
    with open(md_path, "w", encoding="utf-8") as md_file:
        md_file.write(markdown_text)
    logger.info(f"对比 Markdown 已保存: {md_path}")

    # Also print it to the terminal
    print("\n" + markdown_text)
|
||
|
||
|
||
def _build_comparison_markdown(result: "PreviewResult") -> str:
    """Render the old-vs-new chapter comparison as a Markdown document.

    The document contains an overview table (one row per category plus a
    totals row), a per-category preview comparison, and the full text of every
    new chapter. Table cells are sanitised (``|`` escaped, line breaks turned
    into spaces) so a title can never break the table, and every line of a
    multi-line preview is prefixed with ``> `` so the whole preview stays
    inside the block quote.

    Args:
        result: Object exposing ``nickname``, ``phone``, ``generated_at``,
            ``old_chapters`` and ``new_chapters`` (lists of chapter dicts).

    Returns:
        The Markdown text.
    """

    def cell(value) -> str:
        # Keep the value on one table row and stop "|" from ending the cell.
        return str(value).replace("|", "\\|").replace("\r", " ").replace("\n", " ")

    def quote(text: str) -> str:
        # Prefix every line so multi-line previews remain one block quote.
        preview_lines = (text or "").splitlines() or [""]
        return "\n".join("> " + line for line in preview_lines) + "\n"

    lines = []
    lines.append(f"# 回忆录重整对比 — {result.nickname} ({result.phone})")
    lines.append(f"\n生成时间: {result.generated_at}\n")

    # Overview table
    lines.append("## 总览\n")
    lines.append("| 阶段 | 旧标题 | 旧字数 | 新标题 | 新字数 | 变化 |")
    lines.append("|------|--------|--------|--------|--------|------|")

    old_map = {ch["category"]: ch for ch in result.old_chapters}
    new_map = {ch["category"]: ch for ch in result.new_chapters}
    # Categories in first-seen order: old chapters first, then new ones.
    all_stages = list(dict.fromkeys(
        [ch["category"] for ch in result.old_chapters]
        + [ch["category"] for ch in result.new_chapters]
    ))

    total_old = 0
    total_new = 0
    for stage in all_stages:
        old = old_map.get(stage)
        new = new_map.get(stage)
        old_title = old["title"] if old else "—"
        old_len = old["content_len"] if old else 0
        new_title = new["title"] if new else "—"
        new_len = new["content_len"] if new else 0
        total_old += old_len
        total_new += new_len
        diff = new_len - old_len
        lines.append(
            f"| {cell(stage)} | {cell(old_title)} | {old_len} "
            f"| {cell(new_title)} | {new_len} | {diff:+d} |"
        )

    diff_total = total_new - total_old
    lines.append(f"| **合计** | | **{total_old}** | | **{total_new}** | **{diff_total:+d}** |")

    # Per-category preview comparison
    lines.append("\n---\n")
    lines.append("## 各章节详细对比\n")

    for stage in all_stages:
        old = old_map.get(stage)
        new = new_map.get(stage)
        lines.append(f"### {stage}\n")

        lines.append("**旧内容预览:**\n")
        lines.append(quote(old["content_preview"]) if old else "> (无)\n")

        lines.append("**新内容预览:**\n")
        lines.append(quote(new["content_preview"]) if new else "> (无)\n")

    # Full text of the new chapters
    lines.append("\n---\n")
    lines.append("## 新章节完整内容\n")
    for ch in result.new_chapters:
        lines.append(f"### {ch['title']} ({ch['category']}, {ch['content_len']} 字)\n")
        lines.append(ch["content"])
        lines.append("\n")

    return "\n".join(lines)
|
||
|
||
|
||
# ── apply 命令 ────────────────────────────────────────────────
|
||
|
||
|
||
def cmd_apply(phone: str):
    """Write the chapters saved by the last ``preview`` run into the remote DB.

    Loads ``preview_<phone>.json`` from ``OUTPUT_DIR``, asks for an interactive
    ``yes`` confirmation, then, in a single transaction through the SSH tunnel:
    deactivates the user's active chapters, replaces the ``MemoirState`` with a
    default one, inserts the new chapters and their sections, flags the user's
    book as updated (creating it if missing) and marks all of the user's
    segments as processed.

    Exits with status 1 when the preview file is missing or contains no new
    chapters (applying an empty preview would only deactivate the existing
    chapters). On any error the transaction is rolled back and the exception
    propagates.
    """
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    if not os.path.exists(json_path):
        logger.error(f"未找到预览文件: {json_path}")
        logger.error("请先运行 preview 命令")
        sys.exit(1)

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    user_id = data["user_id"]
    new_chapters = data["new_chapters"]
    if not new_chapters:
        # Nothing to write: refuse rather than wipe the user's active chapters.
        logger.error("预览文件中没有新章节,已取消写入")
        sys.exit(1)
    logger.info(f"将写入 {len(new_chapters)} 个新章节到用户 {data['nickname']} ({user_id})")

    # Confirmation
    answer = input("\n确认写入远程数据库? (yes/no): ").strip().lower()
    if answer != "yes":
        logger.info("已取消")
        return

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            # 1. Old chapters -> inactive
            old_active = (
                db.execute(
                    select(Chapter).where(
                        Chapter.user_id == user_id, Chapter.is_active == True
                    )
                )
                .scalars()
                .all()
            )
            for ch in old_active:
                ch.is_active = False
            logger.info(f"已将 {len(old_active)} 个旧章节标记为 inactive")

            # 2. Delete the old MemoirState
            old_state = db.execute(
                select(MemoirState).where(MemoirState.user_id == user_id)
            ).scalar_one_or_none()
            if old_state:
                db.delete(old_state)
                # Emit the DELETE now: within one flush SQLAlchemy issues
                # INSERTs before DELETEs, which would clash if user_id is
                # unique on this table (presumably it is — verify against the model).
                db.flush()
                logger.info("已删除旧 MemoirState")

            # 3. Create a fresh MemoirState
            ds = default_state()
            db.add(MemoirState(
                id=str(uuid.uuid4()),
                user_id=user_id,
                stage_order=ds.stage_order,
                current_stage=ds.current_stage,
                covered_stages=ds.covered_stages,
                slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in ds.slots.items()},
            ))

            # 4. Insert the new chapters (no content/images on the chapter row;
            #    the text goes into chapter_sections)
            last_chapter_id = None
            for ch_data in new_chapters:
                ch_id = str(uuid.uuid4())
                chapter = Chapter(
                    id=ch_id,
                    user_id=user_id,
                    title=ch_data["title"],
                    order_index=ch_data["order_index"],
                    status="completed",
                    category=ch_data["category"],
                    cover_image=None,
                    is_new=True,
                    source_segments=ch_data.get("source_segment_ids", []),
                )
                db.add(chapter)
                db.flush()  # write the chapter row before its sections are added

                content = ch_data.get("content") or ""
                sections = split_narrative_to_sections(content)
                if not sections:
                    # No sections came back: store the whole text as one section.
                    db.add(ChapterSection(
                        id=uuid.uuid4().hex,  # 32 hex chars, no dashes
                        chapter_id=ch_id,
                        order_index=0,
                        content=content.strip(),
                        image=None,
                    ))
                else:
                    for order_idx, seg in enumerate(sections):
                        db.add(ChapterSection(
                            id=uuid.uuid4().hex,
                            chapter_id=ch_id,
                            order_index=order_idx,
                            content=(seg.get("content") or "").strip(),
                            image=None,
                        ))
                last_chapter_id = ch_id
                logger.info(f" 新建章节: [{ch_data['category']}] {ch_data['title']} — {ch_data['content_len']} 字")

            # 5. Update the Book. Take the most recently updated one:
            #    scalar_one_or_none() would raise if the user had several books.
            book = (
                db.execute(
                    select(Book)
                    .where(Book.user_id == user_id)
                    .order_by(Book.updated_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            if not book:
                book = Book(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    title="我的回忆录",
                    total_pages=0,
                    total_words=0,
                    cover_image_url=None,
                )
                db.add(book)
            book.has_update = True
            if last_chapter_id:
                book.last_update_chapter_id = last_chapter_id

            # 6. Mark all of the user's segments as processed
            segs = (
                db.execute(
                    select(Segment)
                    .join(Conversation, Segment.conversation_id == Conversation.id)
                    .where(Conversation.user_id == user_id)
                )
                .scalars()
                .all()
            )
            for seg in segs:
                seg.processed = True

            db.commit()
            logger.info("远程数据库写入完成!")
        except Exception:
            # Leave the remote DB untouched when any step fails.
            db.rollback()
            raise
        finally:
            db.close()
    finally:
        tunnel.stop()
|
||
|
||
|
||
# ── CLI 入口 ──────────────────────────────────────────────────
|
||
|
||
|
||
def main():
    """Parse the command line and run the selected sub-command.

    Two sub-commands exist: ``preview`` (``--phone``, ``--batch-size``,
    ``--skip-llm-slots``) and ``apply`` (``--phone``). A sub-command is required.
    """
    cli = argparse.ArgumentParser(description="重新整理用户历史对话为回忆录章节(远程预览+写入)")
    commands = cli.add_subparsers(dest="command", required=True)

    # preview
    preview_cli = commands.add_parser("preview", help="预览:读取远程 DB,本地生成新章节,输出对比")
    preview_cli.add_argument("--phone", required=True, help="用户手机号")
    preview_cli.add_argument("--batch-size", type=int, default=5, help="每批段落数(默认 5)")
    preview_cli.add_argument("--skip-llm-slots", action="store_true", help="跳过 LLM slot 提取")

    # apply
    apply_cli = commands.add_parser("apply", help="写入:将 preview 结果写入远程 DB")
    apply_cli.add_argument("--phone", required=True, help="用户手机号(需与 preview 一致)")

    opts = cli.parse_args()
    chosen = opts.command

    if chosen == "apply":
        cmd_apply(phone=opts.phone)
        return
    if chosen == "preview":
        cmd_preview(
            phone=opts.phone,
            batch_size=opts.batch_size,
            skip_llm_slots=opts.skip_llm_slots,
        )


if __name__ == "__main__":
    main()
|