feat: 更新对话内容处理逻辑及新增用户回忆整理脚本
- 修改 get_narrative_prompt 函数,优化对话内容的叙述生成逻辑,确保新内容与已有内容自然衔接。 - 在 api/scripts 中新增 reprocess_user_memoir.py 脚本,用于整理用户历史对话为回忆录章节,支持远程预览和确认写入。 - 更新 .gitignore,添加对 certs/ 和 scripts/output/ 目录的忽略规则,确保不必要的文件不被跟踪。 - 在 memoir_tasks.py 中添加章节锁机制,防止并发写入同一章节,提升数据一致性和安全性。
This commit is contained in:
648
api/scripts/reprocess_user_memoir.py
Normal file
648
api/scripts/reprocess_user_memoir.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""
|
||||
重新整理用户历史对话为回忆录章节(远程预览 + 确认后写入)
|
||||
|
||||
用法:
|
||||
cd api
|
||||
|
||||
# 第一步:预览(只读远程 DB,本地生成新章节,输出对比 Markdown)
|
||||
python -m scripts.reprocess_user_memoir preview --phone 13800138000
|
||||
|
||||
# 第二步:确认后写入远程 DB
|
||||
python -m scripts.reprocess_user_memoir apply --phone 13800138000
|
||||
|
||||
流程:
|
||||
preview:
|
||||
1. SSH 隧道连接远程 PostgreSQL
|
||||
2. 读取用户现有章节 + 所有历史对话段落
|
||||
3. 本地调用 LLM 生成新章节(不写入远程 DB)
|
||||
4. 输出对比 Markdown 表格 + 保存结果到 JSON 文件
|
||||
|
||||
apply:
|
||||
1. 读取上次 preview 保存的 JSON 文件
|
||||
2. SSH 隧道连接远程 PostgreSQL
|
||||
3. 旧章节 is_active=False,写入新章节
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
# 确保 api/ 目录在 sys.path 中
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
import socket
|
||||
import subprocess
|
||||
import signal
|
||||
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from database.models import Base, User, Conversation, Segment, Chapter, Book, MemoirState
|
||||
from services.llm_service import LLMService
|
||||
from agents.state_schema import MemoirStateSchema, SlotData, default_state
|
||||
from agents.prompts.memory_prompts import (
|
||||
get_creative_title_prompt,
|
||||
get_narrative_prompt,
|
||||
get_state_extraction_prompt,
|
||||
STAGE_TO_ORDER,
|
||||
)
|
||||
|
||||
# Process-wide logging setup for this script: INFO and above, each record
# prefixed with a timestamp and its level.
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── SSH / DB 配置 ──────────────────────────────────────────────
|
||||
|
||||
# Connection settings. Every value can be overridden through an environment
# variable (including one supplied by the .env file loaded earlier in this
# module); the defaults are the values that used to be hard-coded here, so
# existing invocations keep working unchanged.
# NOTE(review): the defaults still embed a host address and DB credentials in
# source control — consider dropping them once operators rely on .env.
SSH_HOST = os.getenv("MEMOIR_SSH_HOST", "1.15.29.57")
SSH_PORT = int(os.getenv("MEMOIR_SSH_PORT", "22"))
SSH_USER = os.getenv("MEMOIR_SSH_USER", "root")
# Default key location: <repo root>/certs/key.crt, resolved relative to this file.
SSH_KEY_PATH = os.getenv("MEMOIR_SSH_KEY_PATH") or os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "..", "certs", "key.crt",
)

# PostgreSQL endpoint as seen from the SSH host (the far end of the tunnel).
REMOTE_PG_HOST = os.getenv("MEMOIR_REMOTE_PG_HOST", "127.0.0.1")
REMOTE_PG_PORT = int(os.getenv("MEMOIR_REMOTE_PG_PORT", "5432"))
PG_USER = os.getenv("MEMOIR_PG_USER", "postgres")
PG_PASSWORD = os.getenv("MEMOIR_PG_PASSWORD", "postgres")
PG_DATABASE = os.getenv("MEMOIR_PG_DATABASE", "life_echo")

# Directory next to this script where preview JSON / Markdown files are written.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
|
||||
|
||||
# ── 关键字阶段检测 ────────────────────────────────────────────
|
||||
|
||||
STAGE_KEYWORDS = {
|
||||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
|
||||
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
|
||||
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
|
||||
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
|
||||
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
|
||||
}
|
||||
|
||||
|
||||
def _detect_stage(text: str, fallback: str) -> str:
|
||||
msg = text.lower()
|
||||
for stage, keywords in STAGE_KEYWORDS.items():
|
||||
if any(w in msg for w in keywords):
|
||||
return stage
|
||||
return fallback
|
||||
|
||||
|
||||
# ── SSH 隧道 + DB 会话 ────────────────────────────────────────
|
||||
|
||||
|
||||
class SshTunnel:
    """Local port forward to the remote PostgreSQL, run as an ``ssh -N -L`` child process.

    Shelling out to the system ``ssh`` binary keeps the script independent of
    whichever paramiko version happens to be installed.
    """

    def __init__(self, local_port: int = 15432):
        self.local_port = local_port
        self._proc: Optional[subprocess.Popen] = None

    def start(self):
        """Spawn ``ssh`` and block until the forwarded local port accepts connections.

        The port is probed up to 30 times, sleeping 0.5 s after each failed
        attempt.

        Raises:
            RuntimeError: if the ssh process exits early, or the port is still
                unreachable after the last attempt. In both cases the child
                process and its stderr pipe are cleaned up before raising.
        """
        key_path = os.path.normpath(SSH_KEY_PATH)
        cmd = [
            "ssh", "-N", "-L",
            f"{self.local_port}:{REMOTE_PG_HOST}:{REMOTE_PG_PORT}",
            "-i", key_path,
            "-p", str(SSH_PORT),
            # NOTE(review): host-key verification is disabled, which allows a
            # man-in-the-middle on the DB connection. Left as-is to avoid
            # breaking hosts missing from known_hosts; consider "accept-new".
            "-o", "StrictHostKeyChecking=no",
            "-o", "ExitOnForwardFailure=yes",
            "-o", "BatchMode=yes",
            f"{SSH_USER}@{SSH_HOST}",
        ]
        logger.info(f"SSH 隧道: {SSH_USER}@{SSH_HOST}:{SSH_PORT} -> 127.0.0.1:{self.local_port}, key={key_path}")
        self._proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)

        for attempt in range(30):
            if self._proc.poll() is not None:
                err = self._drain_stderr()
                raise RuntimeError(f"SSH 隧道进程已退出: {err}")
            try:
                # Only reachability matters; the probe connection is closed at once.
                with socket.create_connection(("127.0.0.1", self.local_port), timeout=1):
                    pass
            except OSError:  # includes ConnectionRefusedError and socket timeouts
                time.sleep(0.5)
                continue
            logger.info(f"SSH 隧道已建立, 本地端口: {self.local_port} (耗时 {attempt * 0.5:.1f}s)")
            return

        # Timed out: kill the child so it is not leaked, then report whatever
        # it wrote to stderr (safe to read now that the process has exited).
        self._terminate()
        err = self._drain_stderr()
        raise RuntimeError(f"SSH 隧道端口 {self.local_port} 超时未就绪: {err}")

    def stop(self):
        """Terminate the ssh child if it is still running and release its stderr pipe.

        Safe to call when the tunnel was never started or has already exited.
        """
        if self._terminate():
            logger.info("SSH 隧道已关闭")
        proc = self._proc
        if proc is not None and proc.stderr is not None:
            proc.stderr.close()

    def _terminate(self) -> bool:
        """Stop a running child (SIGTERM, then SIGKILL after 5 s); return True if one was stopped."""
        proc = self._proc
        if proc is None or proc.poll() is not None:
            return False
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # ssh ignored SIGTERM; force it down so the caller never hangs.
            proc.kill()
            proc.wait()
        return True

    def _drain_stderr(self) -> str:
        """Read and close the child's stderr; call only after the child has exited."""
        proc = self._proc
        if proc is None or proc.stderr is None or proc.stderr.closed:
            return ""
        try:
            # errors="replace": a decoding problem must not mask the real error.
            return proc.stderr.read().decode(errors="replace")
        finally:
            proc.stderr.close()

    @property
    def local_bind_port(self) -> int:
        """Local port the tunnel listens on (same value as ``local_port``)."""
        return self.local_port
|
||||
|
||||
|
||||
def open_ssh_tunnel() -> SshTunnel:
    """Create an ``SshTunnel`` with its default local port, start it and return it."""
    started_tunnel = SshTunnel()
    started_tunnel.start()
    return started_tunnel
|
||||
|
||||
|
||||
def make_session(tunnel: SshTunnel) -> Session:
    """Open a SQLAlchemy session to PostgreSQL through the tunnel's local port.

    A fresh engine is created on every call; the caller is responsible for
    closing the returned session.
    """
    port = tunnel.local_bind_port
    dsn = f"postgresql+psycopg://{PG_USER}:{PG_PASSWORD}@127.0.0.1:{port}/{PG_DATABASE}"
    engine = create_engine(dsn, pool_size=2, max_overflow=2)
    session_factory = sessionmaker(bind=engine)
    return session_factory()
|
||||
|
||||
|
||||
# ── 数据结构:保存生成结果 ────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class GeneratedChapter:
    """One chapter built in memory by the generation step."""

    # Stage name the chapter was generated for.
    category: str
    title: str
    content: str
    order_index: int
    # IDs of the segments whose text went into this chapter.
    source_segment_ids: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class PreviewResult:
    """Everything a preview run produces; serialised to JSON for the apply step."""

    user_id: str
    phone: str
    nickname: str
    generated_at: str
    # Each entry: {category, title, content_len, content_preview}
    old_chapters: List[dict] = field(default_factory=list)
    # Same shape as old_chapters, plus the full content
    new_chapters: List[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
# ── 核心:本地生成章节 ────────────────────────────────────────
|
||||
|
||||
|
||||
def _strip_code_fence(raw: str) -> str:
    """Remove a surrounding Markdown code fence (``` or ```json) from an LLM reply."""
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line, including any language tag on it.
        first_newline = text.find("\n")
        text = text[first_newline + 1:] if first_newline != -1 else ""
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    return text.strip()


def extract_slots_with_llm(llm, text: str, current_stage: str, stage_slots: dict):
    """Ask the LLM which stage ``text`` belongs to and which slots it fills.

    Args:
        llm: object exposing ``invoke(prompt)`` whose result has a ``content`` string.
        text: the user's transcript text to analyse.
        current_stage: stage to fall back to when the reply names none.
        stage_slots: slot data forwarded to the extraction prompt.

    Returns:
        ``(detected_stage, slots)``. This is best effort: on any failure
        (LLM error, invalid JSON, reply that is not a JSON object) a warning
        is logged and ``(current_stage, {})`` is returned instead of raising.
    """
    try:
        prompt = get_state_extraction_prompt(
            user_message=text,
            current_stage=current_stage,
            stage_slots=stage_slots,
        )
        response = llm.invoke(prompt)
        # Models frequently wrap JSON in a ```json fence; unwrap it before parsing.
        parsed = json.loads(_strip_code_fence(response.content))
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")

        # A null / empty / non-string stage must not become the stage key downstream.
        detected_stage = parsed.get("detected_stage")
        if not isinstance(detected_stage, str) or not detected_stage.strip():
            detected_stage = current_stage

        # Callers iterate slots with .items(), so anything but a dict becomes {}.
        slots = parsed.get("slots")
        if not isinstance(slots, dict):
            slots = {}
        return detected_stage, slots
    except Exception as e:
        logger.warning(f"LLM slot 提取失败: {e}")
        return current_stage, {}
|
||||
|
||||
|
||||
def generate_chapters_in_memory(
    segments: list,  # list of (id, transcript_text)
    llm,
    batch_size: int,
    skip_llm_slots: bool,
) -> List[GeneratedChapter]:
    """Generate chapters purely in memory; nothing is written to any DB.

    Segments are first assigned to a stage (keyword detection, optionally
    refined by LLM slot extraction). Each stage's segments are then fed to the
    LLM in batches, and every batch's narrative is appended to the chapter.

    Args:
        segments: ``(segment_id, transcript_text)`` pairs; blank texts are skipped.
        llm: object exposing ``invoke(prompt)`` whose result has a ``content`` string.
        batch_size: number of segments per narrative call; must be >= 1.
        skip_llm_slots: when True, only keyword-based stage detection is used.

    Returns:
        One ``GeneratedChapter`` per detected stage, in first-seen stage order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # Fail before any LLM call is spent: a zero step breaks range(), and a
    # negative one would silently produce chapters with empty content.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    state = default_state()

    # 1. Stage detection & slot extraction (state lives in memory only)
    stage_to_segments: Dict[str, list] = {}

    for idx, (seg_id, text) in enumerate(segments, 1):
        if not text or not text.strip():
            continue

        detected_stage = _detect_stage(text, state.current_stage)

        if not skip_llm_slots:
            try:
                detected_stage, extracted_slots = extract_slots_with_llm(
                    llm, text, state.current_stage, state.slots.get(detected_stage, {})
                )
                # Update the in-memory slots for the detected stage
                for slot_name, snippet in extracted_slots.items():
                    stage_slots = state.slots.get(detected_stage, {})
                    stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=[seg_id])
                    state.slots[detected_stage] = stage_slots
                state.current_stage = detected_stage
            except Exception as e:
                logger.warning(f"段落 {idx} slot 提取失败: {e}")

        stage_to_segments.setdefault(detected_stage, []).append((seg_id, text))

        if idx % 20 == 0:
            logger.info(f"阶段检测进度: {idx}/{len(segments)}")

    for stage, segs in stage_to_segments.items():
        logger.info(f"阶段 [{stage}]: {len(segs)} 条段落")

    # 2. Generate each stage's chapter batch by batch
    results: List[GeneratedChapter] = []

    for stage, seg_list in stage_to_segments.items():
        title = f"{stage} 回忆"  # default, kept unless the LLM yields a usable title
        existing_content = ""
        all_source_ids: List[str] = []

        slot_snippets = {
            key: value.snippet
            for key, value in (state.slots.get(stage, {}) or {}).items()
            if value.snippet
        }
        total_batches = (len(seg_list) + batch_size - 1) // batch_size

        for i in range(0, len(seg_list), batch_size):
            batch = seg_list[i : i + batch_size]
            batch_num = i // batch_size + 1
            logger.info(f"[{stage}] 处理第 {batch_num}/{total_batches} 批 ({len(batch)} 条)")

            combined_text = "\n\n".join(text for _, text in batch)
            all_source_ids.extend(sid for sid, _ in batch)

            # Fallback: if the LLM fails or returns nothing, keep the raw
            # transcript so no source text is lost from the chapter.
            addition = combined_text

            try:
                if not existing_content:
                    # First batch: also generate the chapter title
                    title_prompt = get_creative_title_prompt(
                        stage=stage, emotion="neutral", slots=slot_snippets
                    )
                    generated_title = llm.invoke(title_prompt).content.strip().strip('"')
                    if generated_title:  # an empty reply must not erase the default
                        title = generated_title
                        logger.info(f"[{stage}] 生成标题: {title}")

                narrative_prompt = get_narrative_prompt(
                    stage=stage,
                    slots=slot_snippets,
                    new_content=combined_text,
                    existing_content=existing_content,
                )
                generated_narrative = llm.invoke(narrative_prompt).content.strip()
                if generated_narrative:
                    addition = generated_narrative
                else:
                    logger.warning(f"[{stage}] LLM 返回空内容, 回退为原始文本")
            except Exception as e:
                logger.warning(f"[{stage}] LLM 生成失败: {e}")

            # Content only ever grows by appending, so it can never shrink
            # between batches and no length sanity check is needed.
            if existing_content:
                existing_content = f"{existing_content}\n\n{addition}"
            else:
                existing_content = addition

            logger.info(f"[{stage}] 批次 {batch_num} 完成, 累计长度: {len(existing_content)} 字")
            if i + batch_size < len(seg_list):
                time.sleep(1)  # pause between batches of the same stage

        results.append(GeneratedChapter(
            category=stage,
            title=title,
            content=existing_content,
            order_index=STAGE_TO_ORDER.get(stage, 999),
            source_segment_ids=all_source_ids,
        ))

    return results
|
||||
|
||||
|
||||
# ── preview 命令 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def cmd_preview(phone: str, batch_size: int, skip_llm_slots: bool):
    """Run the ``preview`` command for the user identified by ``phone``.

    Reads the user's active chapters and all conversation segments through the
    SSH tunnel, closes the tunnel, generates new chapters locally with the LLM,
    then writes ``preview_<phone>.json`` and ``preview_<phone>.md`` to
    ``OUTPUT_DIR`` and prints the Markdown. Exits with status 1 when the LLM is
    not configured or the user is not found; returns early when the user has no
    segments.
    """

    def snippet(body) -> str:
        # First 200 characters plus an ellipsis, or the whole (possibly empty) text.
        body = body or ""
        return body[:200] + "…" if len(body) > 200 else body

    # The LLM is checked first so nothing remote is touched when it is missing.
    llm = LLMService().get_llm()
    if not llm:
        logger.error("LLM 未配置,请检查 .env 中的 DEEPSEEK_API_KEY")
        sys.exit(1)
    logger.info("LLM 就绪")

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            # Look the user up by phone number
            user = db.execute(select(User).where(User.phone == phone)).scalar_one_or_none()
            if not user:
                logger.error(f"未找到手机号 {phone} 的用户")
                sys.exit(1)
            user_id, nickname = user.id, user.nickname
            logger.info(f"用户: {nickname} (id={user_id})")

            # Currently active chapters, summarised for the comparison
            active_chapters_query = (
                select(Chapter)
                .where(Chapter.user_id == user_id, Chapter.is_active == True)
                .order_by(Chapter.order_index)
            )
            old_chapters = db.execute(active_chapters_query).scalars().all()
            old_chapter_data = [
                {
                    "category": ch.category,
                    "title": ch.title,
                    "content_len": len(ch.content) if ch.content else 0,
                    "content_preview": snippet(ch.content),
                }
                for ch in old_chapters
            ]
            logger.info(f"现有章节: {len(old_chapters)} 个")

            # Every segment of every conversation of this user, oldest first
            segments_query = (
                select(Segment.id, Segment.transcript_text)
                .join(Conversation, Segment.conversation_id == Conversation.id)
                .where(Conversation.user_id == user_id)
                .order_by(Segment.created_at.asc())
            )
            segments_raw = db.execute(segments_query).all()
            logger.info(f"历史段落: {len(segments_raw)} 条")
            if not segments_raw:
                logger.warning("没有对话段落,无需处理")
                return
        finally:
            db.close()
    finally:
        tunnel.stop()

    # Generation happens locally and needs no DB connection
    seg_tuples = [(record[0], record[1]) for record in segments_raw]
    new_chapters = generate_chapters_in_memory(seg_tuples, llm, batch_size, skip_llm_slots)

    # Assemble the comparison payload
    new_chapter_data = [
        {
            "category": ch.category,
            "title": ch.title,
            "content_len": len(ch.content),
            "content_preview": snippet(ch.content),
            "content": ch.content,
            "order_index": ch.order_index,
            "source_segment_ids": ch.source_segment_ids,
        }
        for ch in new_chapters
    ]
    result = PreviewResult(
        user_id=user_id,
        phone=phone,
        nickname=nickname,
        generated_at=datetime.now(timezone.utc).isoformat(),
        old_chapters=old_chapter_data,
        new_chapters=new_chapter_data,
    )

    # Persist the JSON consumed later by the apply command
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(asdict(result), json_file, ensure_ascii=False, indent=2)
    logger.info(f"预览结果已保存: {json_path}")

    # Persist the human-readable comparison
    md_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.md")
    markdown = _build_comparison_markdown(result)
    with open(md_path, "w", encoding="utf-8") as md_file:
        md_file.write(markdown)
    logger.info(f"对比 Markdown 已保存: {md_path}")

    # Echo it to the terminal as well
    print("\n" + markdown)
|
||||
|
||||
|
||||
def _build_comparison_markdown(result: PreviewResult) -> str:
|
||||
lines = []
|
||||
lines.append(f"# 回忆录重整对比 — {result.nickname} ({result.phone})")
|
||||
lines.append(f"\n生成时间: {result.generated_at}\n")
|
||||
|
||||
# 总览表格
|
||||
lines.append("## 总览\n")
|
||||
lines.append("| 阶段 | 旧标题 | 旧字数 | 新标题 | 新字数 | 变化 |")
|
||||
lines.append("|------|--------|--------|--------|--------|------|")
|
||||
|
||||
old_map = {ch["category"]: ch for ch in result.old_chapters}
|
||||
new_map = {ch["category"]: ch for ch in result.new_chapters}
|
||||
all_stages = list(dict.fromkeys(
|
||||
[ch["category"] for ch in result.old_chapters]
|
||||
+ [ch["category"] for ch in result.new_chapters]
|
||||
))
|
||||
|
||||
total_old = 0
|
||||
total_new = 0
|
||||
for stage in all_stages:
|
||||
old = old_map.get(stage)
|
||||
new = new_map.get(stage)
|
||||
old_title = old["title"] if old else "—"
|
||||
old_len = old["content_len"] if old else 0
|
||||
new_title = new["title"] if new else "—"
|
||||
new_len = new["content_len"] if new else 0
|
||||
total_old += old_len
|
||||
total_new += new_len
|
||||
diff = new_len - old_len
|
||||
diff_str = f"+{diff}" if diff >= 0 else str(diff)
|
||||
lines.append(f"| {stage} | {old_title} | {old_len} | {new_title} | {new_len} | {diff_str} |")
|
||||
|
||||
diff_total = total_new - total_old
|
||||
diff_total_str = f"+{diff_total}" if diff_total >= 0 else str(diff_total)
|
||||
lines.append(f"| **合计** | | **{total_old}** | | **{total_new}** | **{diff_total_str}** |")
|
||||
|
||||
# 各章节详细对比
|
||||
lines.append("\n---\n")
|
||||
lines.append("## 各章节详细对比\n")
|
||||
|
||||
for stage in all_stages:
|
||||
old = old_map.get(stage)
|
||||
new = new_map.get(stage)
|
||||
lines.append(f"### {stage}\n")
|
||||
|
||||
lines.append("**旧内容预览:**\n")
|
||||
if old:
|
||||
lines.append(f"> {old['content_preview']}\n")
|
||||
else:
|
||||
lines.append("> (无)\n")
|
||||
|
||||
lines.append("**新内容预览:**\n")
|
||||
if new:
|
||||
lines.append(f"> {new['content_preview']}\n")
|
||||
else:
|
||||
lines.append("> (无)\n")
|
||||
|
||||
# 新章节完整内容
|
||||
lines.append("\n---\n")
|
||||
lines.append("## 新章节完整内容\n")
|
||||
for ch in result.new_chapters:
|
||||
lines.append(f"### {ch['title']} ({ch['category']}, {ch['content_len']} 字)\n")
|
||||
lines.append(ch["content"])
|
||||
lines.append("\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── apply 命令 ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def cmd_apply(phone: str):
    """Run the ``apply`` command: write the saved preview for ``phone`` to the remote DB.

    Loads ``preview_<phone>.json`` from ``OUTPUT_DIR``, asks for an explicit
    "yes" on stdin, then — in a single transaction through the SSH tunnel —
    deactivates the user's active chapters, replaces the ``MemoirState`` with a
    default one, inserts the new chapters, flags the user's book as updated and
    marks all of the user's segments as processed.

    Exits with status 1 when no preview file exists; returns without touching
    the DB when the confirmation is anything other than "yes".
    """
    json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json")
    if not os.path.exists(json_path):
        logger.error(f"未找到预览文件: {json_path}")
        logger.error("请先运行 preview 命令")
        sys.exit(1)

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    user_id = data["user_id"]
    new_chapters = data["new_chapters"]
    logger.info(f"将写入 {len(new_chapters)} 个新章节到用户 {data['nickname']} ({user_id})")

    # Explicit confirmation before anything is written
    answer = input("\n确认写入远程数据库? (yes/no): ").strip().lower()
    if answer != "yes":
        logger.info("已取消")
        return

    tunnel = open_ssh_tunnel()
    try:
        db = make_session(tunnel)
        try:
            # 1. Old chapters -> inactive
            old_active = (
                db.execute(
                    select(Chapter).where(
                        Chapter.user_id == user_id, Chapter.is_active == True
                    )
                )
                .scalars()
                .all()
            )
            for ch in old_active:
                ch.is_active = False
            logger.info(f"已将 {len(old_active)} 个旧章节标记为 inactive")

            # 2. Delete the old MemoirState
            old_state = db.execute(
                select(MemoirState).where(MemoirState.user_id == user_id)
            ).scalar_one_or_none()
            if old_state:
                db.delete(old_state)
                # Emit the DELETE now: within one flush SQLAlchemy issues
                # INSERTs before DELETEs, so the replacement row below could
                # otherwise collide with the old one on a unique user_id.
                db.flush()
                logger.info("已删除旧 MemoirState")

            # 3. Create a fresh MemoirState from the defaults
            ds = default_state()
            db.add(MemoirState(
                id=str(uuid.uuid4()),
                user_id=user_id,
                stage_order=ds.stage_order,
                current_stage=ds.current_stage,
                covered_stages=ds.covered_stages,
                slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in ds.slots.items()},
            ))

            # 4. Insert the new chapters
            last_chapter_id = None
            for ch_data in new_chapters:
                ch_id = str(uuid.uuid4())
                chapter = Chapter(
                    id=ch_id,
                    user_id=user_id,
                    title=ch_data["title"],
                    content=ch_data["content"],
                    order_index=ch_data["order_index"],
                    status="completed",
                    category=ch_data["category"],
                    images=[],
                    is_new=True,
                    source_segments=ch_data.get("source_segment_ids", []),
                )
                db.add(chapter)
                last_chapter_id = ch_id
                logger.info(f" 新建章节: [{ch_data['category']}] {ch_data['title']} — {ch_data['content_len']} 字")

            # 5. Update the user's most recently updated Book (create one if missing).
            # first() instead of scalar_one_or_none(): the latter raises
            # MultipleResultsFound as soon as the user owns more than one book.
            book = (
                db.execute(
                    select(Book)
                    .where(Book.user_id == user_id)
                    .order_by(Book.updated_at.desc())
                    .limit(1)
                )
                .scalars()
                .first()
            )
            if not book:
                book = Book(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    title="我的回忆录",
                    total_pages=0,
                    total_words=0,
                    cover_image_url=None,
                )
                db.add(book)
            book.has_update = True
            if last_chapter_id:
                book.last_update_chapter_id = last_chapter_id

            # 6. Mark every segment of the user as processed.
            # NOTE(review): this also covers segments created after the preview
            # ran, which are in no generated chapter — confirm that is intended.
            segs = (
                db.execute(
                    select(Segment)
                    .join(Conversation, Segment.conversation_id == Conversation.id)
                    .where(Conversation.user_id == user_id)
                )
                .scalars()
                .all()
            )
            for seg in segs:
                seg.processed = True

            db.commit()
            logger.info("远程数据库写入完成!")

        finally:
            db.close()
    finally:
        tunnel.stop()
|
||||
|
||||
|
||||
# ── CLI 入口 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse the arguments and dispatch to ``preview`` or ``apply``."""
    parser = argparse.ArgumentParser(description="重新整理用户历史对话为回忆录章节(远程预览+写入)")
    subcommands = parser.add_subparsers(dest="command", required=True)

    # preview: read the remote DB, generate locally, print the comparison
    preview_parser = subcommands.add_parser("preview", help="预览:读取远程 DB,本地生成新章节,输出对比")
    preview_parser.add_argument("--phone", required=True, help="用户手机号")
    preview_parser.add_argument("--batch-size", type=int, default=5, help="每批段落数(默认 5)")
    preview_parser.add_argument("--skip-llm-slots", action="store_true", help="跳过 LLM slot 提取")

    # apply: write a previously saved preview to the remote DB
    apply_parser = subcommands.add_parser("apply", help="写入:将 preview 结果写入远程 DB")
    apply_parser.add_argument("--phone", required=True, help="用户手机号(需与 preview 一致)")

    options = parser.parse_args()

    if options.command == "preview":
        cmd_preview(
            phone=options.phone,
            batch_size=options.batch_size,
            skip_llm_slots=options.skip_llm_slots,
        )
        return
    if options.command == "apply":
        cmd_apply(phone=options.phone)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user