186 lines
5.0 KiB
Python
186 lines
5.0 KiB
Python
|
|
"""Story repository — Story, StoryVersion, StoryEvidenceLink data access."""
|
|||
|
|
|
|||
|
|
import uuid
from datetime import datetime, timezone

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.features.story.models import (
    Story,
    StoryEvidenceLink,
    StoryImageIntent,
    StoryVersion,
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _new_id() -> str:
    """Return a fresh random (version 4) UUID in its canonical string form."""
    generated = uuid.uuid4()
    return str(generated)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def create_story(
    db: AsyncSession,
    *,
    user_id: str,
    title: str,
    stage: str | None = None,
    story_type: str | None = None,
    summary: str | None = None,
    canonical_markdown: str | None = None,
) -> Story:
    """Build a new ``Story`` and stage it on the session.

    The object is only added to ``db``; nothing is flushed or committed here,
    so the caller must commit.

    Args:
        db: Async session the new story is added to.
        user_id: Owner of the story.
        title: Story title.
        stage: Optional stage value, stored as given.
        story_type: Optional story type, stored as given.
        summary: Optional summary, stored as given.
        canonical_markdown: Initial markdown body; ``None`` (or any other
            falsy value) is stored as an empty string.

    Returns:
        The pending ``Story`` instance with a freshly generated ``id``.
    """
    markdown_body = canonical_markdown if canonical_markdown else ""
    new_story = Story(
        id=_new_id(),
        user_id=user_id,
        title=title,
        stage=stage,
        story_type=story_type,
        summary=summary,
        canonical_markdown=markdown_body,
    )
    db.add(new_story)
    return new_story
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def create_story_version(
    db: AsyncSession,
    *,
    story_id: str,
    version_no: int,
    markdown_snapshot: str,
    actor_type: str = "ai",
    source_type: str = "generate",
    parent_version_id: str | None = None,
    prompt_meta: dict | None = None,
) -> StoryVersion:
    """Build a new ``StoryVersion`` snapshot and stage it on the session.

    Only ``db.add`` is called; the caller must commit.

    Args:
        db: Async session the version is added to.
        story_id: Story this version belongs to.
        version_no: Version number, stored as given (not computed here).
        markdown_snapshot: Full markdown content captured by this version.
        actor_type: Who produced the version; defaults to ``"ai"``.
        source_type: How the version was produced; defaults to ``"generate"``.
        parent_version_id: Optional id of the version this one derives from.
        prompt_meta: Optional metadata dict, stored as given.

    Returns:
        The pending ``StoryVersion`` instance with a freshly generated ``id``.
    """
    column_values = {
        "story_id": story_id,
        "version_no": version_no,
        "markdown_snapshot": markdown_snapshot,
        "actor_type": actor_type,
        "source_type": source_type,
        "parent_version_id": parent_version_id,
        "prompt_meta": prompt_meta,
    }
    snapshot = StoryVersion(id=_new_id(), **column_values)
    db.add(snapshot)
    return snapshot
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def create_story_evidence_link(
    db: AsyncSession,
    *,
    story_id: str,
    evidence_type: str,
    evidence_id: str,
    role: str = "primary",
    weight: float | None = None,
) -> StoryEvidenceLink:
    """Build a link between a story and one piece of evidence and stage it.

    The link is only added to ``db``; the caller must commit.

    Args:
        db: Async session the link is added to.
        story_id: Story being linked.
        evidence_type: Kind of evidence referenced, stored as given.
        evidence_id: Identifier of the evidence, stored as given.
        role: Role of the evidence for the story; defaults to ``"primary"``.
        weight: Optional weight, stored as given.

    Returns:
        The pending ``StoryEvidenceLink`` with a freshly generated ``id``.
    """
    evidence_link = StoryEvidenceLink(
        id=_new_id(),
        story_id=story_id,
        evidence_type=evidence_type,
        evidence_id=evidence_id,
        role=role,
        weight=weight,
    )
    db.add(evidence_link)
    return evidence_link
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def get_story_by_id(db: AsyncSession, story_id: str) -> Story | None:
    """Look up one story by primary key; return ``None`` when it does not exist."""
    found = await db.get(Story, story_id)
    return found
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def get_stories_for_user(
    db: AsyncSession, user_id: str, *, status: str | None = "active"
) -> list[Story]:
    """List a user's stories, newest first.

    Args:
        db: Async session to query with.
        user_id: Owner whose stories are returned.
        status: When truthy, only stories with this exact status are
            returned (default ``"active"``); pass ``None`` or ``""`` to
            return stories of every status.

    Returns:
        Stories ordered by ``created_at`` descending.
    """
    conditions = [Story.user_id == user_id]
    if status:
        conditions.append(Story.status == status)
    query = (
        select(Story)
        .where(*conditions)
        .order_by(Story.created_at.desc())
    )
    rows = await db.execute(query)
    return list(rows.unique().scalars().all())
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def count_story_versions(db: AsyncSession, story_id: str) -> int:
    """Return how many ``StoryVersion`` rows belong to ``story_id``.

    Args:
        db: Async session to query with.
        story_id: Story whose versions are counted.

    Returns:
        The number of versions; ``0`` when the story has none.
    """
    stmt = select(func.count(StoryVersion.id)).where(StoryVersion.story_id == story_id)
    result = await db.execute(stmt)
    # An aggregate COUNT without GROUP BY always yields exactly one row (0 when
    # nothing matches), so scalar_one() states that expectation explicitly.
    return result.scalar_one()
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def create_story_image_intent(
    db: AsyncSession,
    *,
    story_id: str,
    story_version_id: str | None,
    caption: str,
    prompt_brief: str,
    style_profile: str | None = None,
) -> StoryImageIntent:
    """Create the primary image intent for a story and stage it on the session.

    The intent is always created with ``intent_role="primary"`` and
    ``status="pending"``. It is only added to ``db``; the caller must commit.

    Args:
        db: Async session the intent is added to.
        story_id: Story the image is for.
        story_version_id: Optional story version the intent is tied to.
        caption: Caption text, stored as given.
        prompt_brief: Brief for the image prompt, stored as given.
        style_profile: Optional style profile, stored as given.

    Returns:
        The pending ``StoryImageIntent`` with a freshly generated ``id``.
    """
    image_intent = StoryImageIntent(
        id=_new_id(),
        story_id=story_id,
        story_version_id=story_version_id,
        intent_role="primary",
        caption=caption,
        prompt_brief=prompt_brief,
        style_profile=style_profile,
        status="pending",
    )
    db.add(image_intent)
    return image_intent
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def get_story_image_intent_by_story(
    db: AsyncSession, story_id: str, *, role: str = "primary"
) -> StoryImageIntent | None:
    """Return a story's image intent for the given role, or ``None``.

    Args:
        db: Async session to query with.
        story_id: Story whose intent is fetched.
        role: Intent role to match; defaults to ``"primary"``.

    Returns:
        The matching intent, or ``None`` if there is none.

    Note:
        Uses ``scalar_one_or_none``, which raises SQLAlchemy's
        ``MultipleResultsFound`` if more than one intent matches
        ``(story_id, role)``.
    """
    query = select(StoryImageIntent).where(
        StoryImageIntent.story_id == story_id,
        StoryImageIntent.intent_role == role,
    )
    outcome = await db.execute(query)
    return outcome.unique().scalar_one_or_none()
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def delete_story_image_intents_by_story(
    db: AsyncSession,
    story_id: str,
    *,
    role: str = "primary",
    statuses: list[str] | None = None,
) -> int:
    """Delete a story's image intents for one role.

    When ``statuses`` is ``None``, every intent with this role is deleted.
    Otherwise only intents whose status is in ``statuses`` are deleted, e.g.
    clear just ``pending``/``failed`` so an intent in ``processing`` is not
    interrupted. This function does not commit.

    Args:
        db: Async session the DELETE is executed on.
        story_id: Story whose intents are removed.
        role: Intent role to match; defaults to ``"primary"``.
        statuses: Optional allow-list of statuses to delete.

    Returns:
        The row count reported by the driver for the DELETE, or ``0`` when
        the driver reports none.
    """
    criteria = [
        StoryImageIntent.story_id == story_id,
        StoryImageIntent.intent_role == role,
    ]
    if statuses is not None:
        criteria.append(StoryImageIntent.status.in_(statuses))
    outcome = await db.execute(delete(StoryImageIntent).where(*criteria))
    return outcome.rowcount or 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def get_stories_by_ids(db: AsyncSession, story_ids: list[str]) -> list[Story]:
    """Fetch stories by ID, preserving the order of ``story_ids``.

    Args:
        db: Async session to query with.
        story_ids: IDs to look up. IDs with no matching row are skipped, so
            the result may be shorter than the input. If an ID appears more
            than once, its last position determines its place in the result.

    Returns:
        Matching stories, ordered by the position of their ID in
        ``story_ids``. An empty ``story_ids`` returns ``[]`` without querying.
    """
    if not story_ids:
        return []
    stmt = select(Story).where(Story.id.in_(story_ids))
    result = await db.execute(stmt)
    stories = list(result.unique().scalars().all())
    position = {sid: i for i, sid in enumerate(story_ids)}
    # A row whose id is not a key of `position` (e.g. the DB matched it
    # case-insensitively) sorts after every requested position. Using
    # len(story_ids) instead of a fixed sentinel such as 999 keeps that true
    # for inputs of any length.
    fallback = len(story_ids)
    return sorted(stories, key=lambda s: position.get(s.id, fallback))
|