"""Memoir repository — Book, Chapter, MemoirState data access.""" import uuid from datetime import datetime, timezone from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, joinedload from app.features.memoir.models import ( Book, Chapter, ChapterSection, ChapterVersion, MemoirState, ) async def get_current_book(user_id: str, db: AsyncSession) -> Book | None: stmt = ( select(Book) .where(Book.user_id == user_id) .order_by(Book.updated_at.desc()) .limit(1) ) result = await db.execute(stmt) return result.scalar_one_or_none() async def get_chapters_with_sections( user_id: str, db: AsyncSession, *, active_only: bool = True, is_new_only: bool | None = None, ) -> list[Chapter]: stmt = ( select(Chapter) .where(Chapter.user_id == user_id) .options( joinedload(Chapter.sections), joinedload(Chapter.images), joinedload(Chapter.sections).joinedload(ChapterSection.image_record), ) .order_by(Chapter.order_index) ) if active_only: stmt = stmt.where(Chapter.is_active == True) # noqa: E712 if is_new_only is True: stmt = stmt.where(Chapter.is_new == True) # noqa: E712 result = await db.execute(stmt) return list(result.unique().scalars().all()) async def get_chapter_by_id(chapter_id: str, db: AsyncSession) -> Chapter | None: stmt = ( select(Chapter) .where(Chapter.id == chapter_id) .options( joinedload(Chapter.sections), joinedload(Chapter.images), joinedload(Chapter.sections).joinedload(ChapterSection.image_record), ) ) result = await db.execute(stmt) return result.unique().scalars().one_or_none() async def get_memoir_state(user_id: str, db: AsyncSession) -> MemoirState | None: stmt = select(MemoirState).where(MemoirState.user_id == user_id) result = await db.execute(stmt) return result.scalar_one_or_none() def get_archived_chapter_summaries_sync( session: Session, user_id: str, category: str ) -> list[tuple[str, str]]: """获取已删除(is_active=False)的同类别章节的标题与内容摘要,供 AI 参考。""" stmt = ( select(Chapter) .where( Chapter.user_id == user_id, Chapter.category == category, Chapter.is_active == False, # noqa: E712 ) .options(joinedload(Chapter.sections)) .order_by(Chapter.updated_at.desc()) ) result = session.execute(stmt) chapters = list(result.unique().scalars().all()) summaries: list[tuple[str, str]] = [] for ch in chapters: sections = getattr(ch, "sections", None) or [] parts = [ (s.content or "").strip() for s in sorted(sections, key=lambda x: getattr(x, "order_index", 0)) ] combined = "".join(parts) preview = (combined[:200] + "...") if len(combined) > 200 else combined if preview.strip(): summaries.append((ch.title or "", preview)) return summaries def ensure_chapter_markdown_and_version_sync( session: Session, chapter: Chapter, markdown: str, ) -> None: """ 为已有 chapter 设置 canonical_markdown 并创建 chapter_version。 由 _save_narrative_to_sections 调用,确保 markdown 真源与版本链。 """ from sqlalchemy import func count_stmt = select(func.count(ChapterVersion.id)).where( ChapterVersion.chapter_id == chapter.id ) version_no = (session.execute(count_stmt).scalar() or 0) + 1 version = ChapterVersion( id=str(uuid.uuid4()), chapter_id=chapter.id, version_no=version_no, markdown_snapshot=markdown, actor_type="ai", source_type="generate", ) session.add(version) session.flush() chapter.canonical_markdown = markdown chapter.current_version_id = version.id def save_chapter_markdown_sync( session: Session, *, user_id: str, chapter_id: str | None, title: str, category: str, order_index: int, markdown: str, source_segments: list[str] | None = None, ) -> Chapter: """ 将 markdown 写入 chapter.canonical_markdown 和 chapter_versions。 Agent 不直接调用,由 service/task 调用。 若 chapter_id 为 None 则新建章节。 """ if chapter_id: chapter = session.get(Chapter, chapter_id) if not chapter or chapter.user_id != user_id: raise ValueError(f"Chapter {chapter_id} not found or access denied") else: chapter = Chapter( id=str(uuid.uuid4()), user_id=user_id, title=title, category=category, order_index=order_index, status="completed", is_new=True, is_active=True, source_segments=source_segments or [], ) session.add(chapter) session.flush() # 创建 chapter_version from sqlalchemy import func count_stmt = select(func.count(ChapterVersion.id)).where( ChapterVersion.chapter_id == chapter.id ) version_no = (session.execute(count_stmt).scalar() or 0) + 1 version = ChapterVersion( id=str(uuid.uuid4()), chapter_id=chapter.id, version_no=version_no, markdown_snapshot=markdown, actor_type="ai", source_type="generate", ) session.add(version) session.flush() chapter.canonical_markdown = markdown chapter.current_version_id = version.id chapter.title = title chapter.is_new = True if source_segments: chapter.source_segments = list( set((chapter.source_segments or []) + source_segments) ) session.flush() session.refresh(chapter) return chapter