Files
life-echo/api/app/features/memoir/repo.py
2026-03-20 15:15:35 +08:00

419 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Memoir repository — Book, Chapter, MemoirState data access."""
import uuid
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, joinedload
from app.core.db import utc_now
from app.features.asset.models import Asset
from app.features.memoir.asset_resolver import collect_asset_ids_for_chapter
from app.features.memoir.chapter_markdown_compose import (
materialize_chapter_markdown_from_loaded_chapter,
)
from app.features.memoir.models import (
Book,
Chapter,
ChapterCoverIntent,
ChapterStoryLink,
ChapterVersion,
MemoirState,
)
from app.features.story.models import Story
async def get_current_book(user_id: str, db: AsyncSession) -> Book | None:
    """Return the user's most recently updated Book, or None if they have none."""
    newest_first = Book.updated_at.desc()
    query = (
        select(Book)
        .where(Book.user_id == user_id)
        .order_by(newest_first)
        .limit(1)
    )
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
async def get_chapters_for_memoir_list(
    user_id: str,
    db: AsyncSession,
    *,
    active_only: bool = True,
    is_new_only: bool | None = None,
) -> list[Chapter]:
    """Load a user's chapters for the memoir list/detail views (stories-first).

    ``images`` and ``story_links`` (plus each link's ``story``) are
    eager-loaded. Chapters are ordered by ``order_index``.

    Args:
        user_id: Owner whose chapters are returned.
        db: Async database session.
        active_only: When True (default), keep only chapters with ``is_active``.
        is_new_only: When exactly ``True``, keep only chapters with ``is_new``;
            ``False`` and ``None`` apply no ``is_new`` filter.

    Returns:
        The matching chapters as a list.
    """
    criteria = [Chapter.user_id == user_id]
    if active_only:
        criteria.append(Chapter.is_active == True)  # noqa: E712
    if is_new_only is True:
        criteria.append(Chapter.is_new == True)  # noqa: E712

    query = (
        select(Chapter)
        .options(
            joinedload(Chapter.images),
            joinedload(Chapter.story_links).joinedload(ChapterStoryLink.story),
        )
        .where(*criteria)
        .order_by(Chapter.order_index)
    )
    rows = await db.execute(query)
    # Joined eager loading of collections repeats parent rows; unique() dedupes them.
    return list(rows.unique().scalars().all())
async def get_chapters_with_sections(
    user_id: str,
    db: AsyncSession,
    *,
    active_only: bool = True,
    is_new_only: bool | None = None,
) -> list[Chapter]:
    """Legacy name kept for backward compatibility.

    Behaves exactly like ``get_chapters_for_memoir_list`` with the same arguments.
    """
    chapters = await get_chapters_for_memoir_list(
        user_id,
        db,
        active_only=active_only,
        is_new_only=is_new_only,
    )
    return chapters
async def get_chapter_by_id(chapter_id: str, db: AsyncSession) -> Chapter | None:
    """Fetch one chapter by id, with ``images`` and ``story_links.story`` eager-loaded.

    Returns None when no chapter has this id. Only the id is filtered on here;
    no ownership (``user_id``) check is performed.
    """
    eager_options = (
        joinedload(Chapter.images),
        joinedload(Chapter.story_links).joinedload(ChapterStoryLink.story),
    )
    query = select(Chapter).where(Chapter.id == chapter_id).options(*eager_options)
    rows = await db.execute(query)
    return rows.unique().scalars().one_or_none()
async def get_memoir_state(user_id: str, db: AsyncSession) -> MemoirState | None:
    """Return the MemoirState row for ``user_id``, or None if there is none."""
    rows = await db.execute(
        select(MemoirState).where(MemoirState.user_id == user_id)
    )
    return rows.scalar_one_or_none()
def get_archived_chapter_summaries_sync(
    session: Session, user_id: str, category: str
) -> list[tuple[str, str]]:
    """Return ``(title, preview)`` pairs for the user's deleted chapters in a category.

    "Deleted" means ``is_active`` is False. Results are ordered by
    ``updated_at`` descending and are intended as reference material for AI.
    The preview is the stripped ``canonical_markdown``, cut to 200 characters
    with "..." appended when longer. Chapters whose markdown is empty are
    skipped; a missing title becomes "".
    """
    query = (
        select(Chapter)
        .where(
            Chapter.user_id == user_id,
            Chapter.category == category,
            Chapter.is_active == False,  # noqa: E712
        )
        .order_by(Chapter.updated_at.desc())
    )
    archived = session.execute(query).unique().scalars().all()

    limit = 200
    pairs: list[tuple[str, str]] = []
    for chapter in archived:
        body = (getattr(chapter, "canonical_markdown", None) or "").strip()
        if len(body) > limit:
            snippet = body[:limit] + "..."
        else:
            snippet = body
        if not snippet.strip():
            continue
        pairs.append((chapter.title or "", snippet))
    return pairs
def ensure_chapter_markdown_and_version_sync(
    session: Session,
    chapter: Chapter,
    markdown: str,
) -> None:
    """Set ``chapter.canonical_markdown`` and record it as a new ChapterVersion.

    Used by materialization paths that do not go through stories (for example
    ``save_chapter_markdown_sync``). The new version is tagged
    ``actor_type="ai"`` / ``source_type="generate"`` and becomes the
    chapter's ``current_version_id``. The session is flushed; nothing is
    committed here.

    Args:
        session: Synchronous SQLAlchemy session.
        chapter: Chapter to update; its ``id`` must already be assigned.
        markdown: Markdown stored as both the version snapshot and the
            chapter's canonical text.
    """
    # Next version number = number of existing versions + 1.
    # NOTE(review): count-based numbering assumes version rows are never
    # deleted; if they can be, this may reuse an existing version_no.
    count_stmt = select(func.count(ChapterVersion.id)).where(
        ChapterVersion.chapter_id == chapter.id
    )
    version_no = (session.execute(count_stmt).scalar() or 0) + 1
    version = ChapterVersion(
        id=str(uuid.uuid4()),
        chapter_id=chapter.id,
        version_no=version_no,
        markdown_snapshot=markdown,
        actor_type="ai",
        source_type="generate",
    )
    session.add(version)
    # Write the version row before the chapter points at it (presumably for
    # the current_version_id foreign key).
    session.flush()
    chapter.canonical_markdown = markdown
    chapter.current_version_id = version.id
def save_chapter_markdown_sync(
    session: Session,
    *,
    user_id: str,
    chapter_id: str | None,
    title: str,
    category: str,
    order_index: int,
    markdown: str,
    source_segments: list[str] | None = None,
) -> Chapter:
    """Write ``markdown`` into a chapter's canonical text and version history.

    Agents do not call this directly; it is intended for service/task code.
    When ``chapter_id`` is falsy (None or empty), a new active chapter is
    created with ``status="completed"``. In both cases a new ChapterVersion is
    appended, the title is overwritten, ``is_new`` is set, and any
    ``source_segments`` are merged into the chapter's existing ones.

    Args:
        session: Synchronous SQLAlchemy session (flushed, not committed).
        user_id: Owner; must match an existing chapter's ``user_id``.
        chapter_id: Existing chapter to update, or None to create one.
        title: Chapter title (applied to new and existing chapters).
        category: Category for a newly created chapter.
        order_index: Ordering position for a newly created chapter.
        markdown: Markdown to store.
        source_segments: Optional segment ids to merge, de-duplicated.

    Returns:
        The refreshed Chapter.

    Raises:
        ValueError: ``chapter_id`` is given but not found or owned by another user.
    """
    if chapter_id:
        chapter = session.get(Chapter, chapter_id)
        if not chapter or chapter.user_id != user_id:
            raise ValueError(f"Chapter {chapter_id} not found or access denied")
    else:
        chapter = Chapter(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title=title,
            category=category,
            order_index=order_index,
            status="completed",
            is_new=True,
            is_active=True,
            source_segments=source_segments or [],
        )
        session.add(chapter)
        session.flush()

    # Reuse the shared helper instead of duplicating the version-append
    # sequence (count -> insert ChapterVersion -> flush -> point chapter at it).
    ensure_chapter_markdown_and_version_sync(session, chapter, markdown)

    chapter.title = title
    chapter.is_new = True
    if source_segments:
        # Merge and de-duplicate while keeping first-seen order. A set()
        # round-trip gave an arbitrary order that could differ between runs.
        chapter.source_segments = list(
            dict.fromkeys([*(chapter.source_segments or []), *source_segments])
        )
    session.flush()
    session.refresh(chapter)
    return chapter
async def count_chapter_story_links(db: AsyncSession, chapter_id: str) -> int:
    """Return the number of ChapterStoryLink rows for ``chapter_id`` (0 if none)."""
    query = (
        select(func.count())
        .select_from(ChapterStoryLink)
        .where(ChapterStoryLink.chapter_id == chapter_id)
    )
    total = await db.scalar(query)
    return 0 if total is None else int(total)
async def get_chapter_ids_linked_to_story(db: AsyncSession, story_id: str) -> list[str]:
    """Return ids of chapters linked to ``story_id``, de-duplicated in result order."""
    rows = await db.execute(
        select(ChapterStoryLink.chapter_id).where(
            ChapterStoryLink.story_id == story_id
        )
    )
    ordered: dict[str, None] = {}
    for linked_chapter_id in rows.scalars().all():
        ordered.setdefault(linked_chapter_id, None)
    return list(ordered)
async def mark_chapters_dirty_for_story(db: AsyncSession, story_id: str) -> None:
    """Flag every chapter linked to ``story_id`` for markdown re-composition.

    Issues a bulk UPDATE setting ``markdown_compose_dirty=True``; does nothing
    when no chapter links to the story. Nothing is committed here.
    """
    chapter_ids = await get_chapter_ids_linked_to_story(db, story_id)
    if chapter_ids:
        stmt = (
            update(Chapter)
            .where(Chapter.id.in_(chapter_ids))
            .values(markdown_compose_dirty=True)
        )
        await db.execute(stmt)
def mark_chapters_dirty_for_story_sync(session: Session, story_id: str) -> None:
    """Synchronous version of ``mark_chapters_dirty_for_story``.

    Looks up chapters linked to ``story_id`` (de-duplicated, first-seen order)
    and bulk-sets ``markdown_compose_dirty=True`` on them. It is a no-op when
    none are linked and does not commit.
    """
    linked = session.scalars(
        select(ChapterStoryLink.chapter_id).where(
            ChapterStoryLink.story_id == story_id
        )
    ).all()
    chapter_ids = list(dict.fromkeys(linked))
    if not chapter_ids:
        return
    stmt = (
        update(Chapter)
        .where(Chapter.id.in_(chapter_ids))
        .values(markdown_compose_dirty=True)
    )
    session.execute(stmt)
async def get_chapter_with_story_links_for_compose(
    chapter_id: str, db: AsyncSession
) -> Chapter | None:
    """Load a chapter with ``story_links`` and each link's ``story`` eager-loaded.

    Unlike ``get_chapter_by_id``, ``images`` is not eager-loaded. Returns None
    if no chapter has ``chapter_id``.
    """
    links_and_stories = joinedload(Chapter.story_links).joinedload(
        ChapterStoryLink.story
    )
    query = select(Chapter).where(Chapter.id == chapter_id).options(links_and_stories)
    rows = await db.execute(query)
    return rows.unique().scalar_one_or_none()
async def append_chapter_compose_version_async(
    db: AsyncSession,
    chapter: Chapter,
    markdown: str,
) -> None:
    """Record ``markdown`` as a system-composed version and make it current.

    Adds a ChapterVersion tagged ``actor_type="system"`` /
    ``source_type="compose_from_stories"``. Its ``version_no`` is one past the
    chapter's existing version count. The session is flushed, then the chapter
    is updated:
    - canonical markdown and current version id are set;
    - ``markdown_compose_dirty`` is cleared;
    - ``markdown_composed_at`` is stamped.
    """
    existing_count = (
        await db.execute(
            select(func.count(ChapterVersion.id)).where(
                ChapterVersion.chapter_id == chapter.id
            )
        )
    ).scalar()
    next_version_no = (existing_count or 0) + 1
    new_version_id = str(uuid.uuid4())
    db.add(
        ChapterVersion(
            id=new_version_id,
            chapter_id=chapter.id,
            version_no=next_version_no,
            markdown_snapshot=markdown,
            actor_type="system",
            source_type="compose_from_stories",
        )
    )
    await db.flush()

    chapter.canonical_markdown = markdown
    chapter.current_version_id = new_version_id
    chapter.markdown_compose_dirty = False
    chapter.markdown_composed_at = utc_now()
def append_chapter_compose_version_sync(
    session: Session,
    chapter: Chapter,
    markdown: str,
) -> None:
    """Synchronous version of ``append_chapter_compose_version_async``.

    Appends a ``system`` / ``compose_from_stories`` ChapterVersion numbered
    (existing version count + 1) and flushes. It then updates the chapter:
    - points it at the new version and copies the markdown into it;
    - clears ``markdown_compose_dirty``;
    - stamps ``markdown_composed_at``.
    """
    prior_versions = session.execute(
        select(func.count(ChapterVersion.id)).where(
            ChapterVersion.chapter_id == chapter.id
        )
    ).scalar()
    next_version_no = (prior_versions or 0) + 1
    new_version_id = str(uuid.uuid4())
    session.add(
        ChapterVersion(
            id=new_version_id,
            chapter_id=chapter.id,
            version_no=next_version_no,
            markdown_snapshot=markdown,
            actor_type="system",
            source_type="compose_from_stories",
        )
    )
    session.flush()

    chapter.canonical_markdown = markdown
    chapter.current_version_id = new_version_id
    chapter.markdown_compose_dirty = False
    chapter.markdown_composed_at = utc_now()
def compose_chapter_from_story_links_sync(session: Session, chapter_id: str) -> bool:
    """Rebuild a chapter's canonical markdown from its story links and version it.

    Returns:
        True when markdown was composed and a new version appended.
        False when the chapter does not exist (nothing is changed), or when it
        has no story links (``markdown_compose_dirty`` is cleared and the
        session flushed).
    """
    query = (
        select(Chapter)
        .where(Chapter.id == chapter_id)
        .options(joinedload(Chapter.story_links).joinedload(ChapterStoryLink.story))
    )
    chapter = session.execute(query).unique().scalar_one_or_none()
    if not chapter:
        return False

    has_links = bool(chapter.story_links)
    if not has_links:
        # Nothing to compose from; just clear the stale flag.
        chapter.markdown_compose_dirty = False
        session.flush()
        return False

    composed = materialize_chapter_markdown_from_loaded_chapter(chapter)
    append_chapter_compose_version_sync(session, chapter, composed)
    return True
async def replace_chapter_story_links_async(
    db: AsyncSession,
    *,
    chapter_id: str,
    user_id: str,
    story_ids: list[str],
) -> None:
    """Replace a chapter's story links with ``story_ids``, in the given order.

    All validation runs before any write:
    - the chapter must exist and belong to ``user_id``;
    - ``story_ids`` must contain no duplicates;
    - every story must exist and belong to ``user_id``.

    Existing links are then deleted, and one link per id is inserted with
    ``order_index`` equal to its position. An empty ``story_ids`` clears all
    links. The session is flushed but not committed.

    Raises:
        ValueError: if any of the validation checks above fails.
    """
    chapter = await db.get(Chapter, chapter_id)
    if not chapter or chapter.user_id != user_id:
        raise ValueError("Chapter not found or access denied")

    requested = set(story_ids)
    if len(requested) != len(story_ids):
        raise ValueError("Duplicate story_id in story_ids")

    if story_ids:
        owned_rows = await db.execute(
            select(Story.id).where(
                Story.id.in_(story_ids),
                Story.user_id == user_id,
            )
        )
        missing = requested - set(owned_rows.scalars().all())
        if missing:
            raise ValueError(f"Stories not found or not owned: {sorted(missing)}")

    await db.execute(
        delete(ChapterStoryLink).where(ChapterStoryLink.chapter_id == chapter_id)
    )
    await db.flush()
    db.add_all(
        [
            ChapterStoryLink(
                id=str(uuid.uuid4()),
                chapter_id=chapter_id,
                story_id=linked_story_id,
                order_index=position,
            )
            for position, linked_story_id in enumerate(story_ids)
        ]
    )
    await db.flush()
async def collect_cos_storage_keys_for_chapter(
    db: AsyncSession, chapter: Chapter
) -> list[str]:
    """Return the sorted, de-duplicated COS storage keys referenced by a chapter.

    Keys are gathered from:
      * ``storage_key`` of each object in ``chapter.images`` (the in-chapter
        illustrations, presumably MemoirImage rows);
      * Assets whose ids ``collect_asset_ids_for_chapter`` returns. Per the
        original notes these are ``asset://`` references in the body and the
        cover ``cover_asset`` — TODO confirm in asset_resolver;
      * Assets bound through ``ChapterCoverIntent.asset_id``.

    Used to reclaim COS space after a chapter is soft-deleted. Falsy keys and
    ids are ignored.
    """
    images = getattr(chapter, "images", None) or []
    image_keys = (getattr(img, "storage_key", None) for img in images)
    keys: set[str] = {key for key in image_keys if key}

    asset_ids = set(collect_asset_ids_for_chapter(chapter))
    intent_result = await db.execute(
        select(ChapterCoverIntent.asset_id).where(
            ChapterCoverIntent.chapter_id == chapter.id,
            ChapterCoverIntent.asset_id.isnot(None),
        )
    )
    asset_ids.update(str(aid) for aid in intent_result.scalars().all() if aid)

    if asset_ids:
        key_result = await db.execute(
            select(Asset.storage_key).where(
                Asset.id.in_(asset_ids),
                Asset.storage_key.isnot(None),
            )
        )
        keys.update(key for key in key_result.scalars().all() if key)

    return sorted(keys)