"""Story repository — Story, StoryVersion, StoryEvidenceLink data access."""

import uuid
from datetime import datetime, timezone

from sqlalchemy import delete, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app.features.story.models import (
    Story,
    StoryEvidenceLink,
    StoryImageIntent,
    StoryVersion,
)

# Escape character used for LIKE/ILIKE patterns built from user input.
# "/" (rather than backslash) avoids dialect-specific backslash quoting rules.
_LIKE_ESCAPE = "/"


def _new_id() -> str:
    """Return a new random UUID4 as a string, used as a client-side primary key."""
    return str(uuid.uuid4())


def _escape_like(term: str) -> str:
    """Escape LIKE wildcards in *term* so it matches literally.

    Must be paired with ``escape=_LIKE_ESCAPE`` on the ``like``/``ilike`` call.
    The escape character itself is escaped first so it cannot be mistaken for
    a prefix of an escaped wildcard.
    """
    return (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


async def create_story(
    db: AsyncSession,
    *,
    user_id: str,
    title: str,
    stage: str | None = None,
    story_type: str | None = None,
    summary: str | None = None,
    canonical_markdown: str | None = None,
) -> Story:
    """Create a story and add it to the session. Caller must commit.

    The id is generated client-side, so ``story.id`` is usable immediately,
    before any flush. A missing/empty ``canonical_markdown`` is stored as "".
    """
    story = Story(
        id=_new_id(),
        user_id=user_id,
        title=title,
        stage=stage,
        story_type=story_type,
        summary=summary,
        canonical_markdown=canonical_markdown or "",
    )
    db.add(story)
    return story


async def create_story_version(
    db: AsyncSession,
    *,
    story_id: str,
    version_no: int,
    markdown_snapshot: str,
    actor_type: str = "ai",
    source_type: str = "generate",
    parent_version_id: str | None = None,
    prompt_meta: dict | None = None,
) -> StoryVersion:
    """Create a story version and add it to the session. Caller must commit.

    ``version_no`` is taken as given; callers are responsible for choosing it
    (e.g. from :func:`count_story_versions`).
    """
    version = StoryVersion(
        id=_new_id(),
        story_id=story_id,
        version_no=version_no,
        markdown_snapshot=markdown_snapshot,
        actor_type=actor_type,
        source_type=source_type,
        parent_version_id=parent_version_id,
        prompt_meta=prompt_meta,
    )
    db.add(version)
    return version


async def create_story_evidence_link(
    db: AsyncSession,
    *,
    story_id: str,
    evidence_type: str,
    evidence_id: str,
    role: str = "primary",
    weight: float | None = None,
) -> StoryEvidenceLink:
    """Create a story-evidence link and add it to the session. Caller must commit."""
    link = StoryEvidenceLink(
        id=_new_id(),
        story_id=story_id,
        evidence_type=evidence_type,
        evidence_id=evidence_id,
        role=role,
        weight=weight,
    )
    db.add(link)
    return link


async def get_story_by_id(db: AsyncSession, story_id: str) -> Story | None:
    """Fetch a story by primary key, or None if it does not exist."""
    return await db.get(Story, story_id)


async def get_stories_for_user(
    db: AsyncSession, user_id: str, *, status: str | None = "active"
) -> list[Story]:
    """Fetch a user's stories, newest first.

    Args:
        status: Only return stories with this status. A falsy value
            (None or "") disables the status filter.
    """
    stmt = select(Story).where(Story.user_id == user_id)
    if status:
        stmt = stmt.where(Story.status == status)
    stmt = stmt.order_by(Story.created_at.desc())
    result = await db.execute(stmt)
    # NOTE(review): unique() is presumably needed because Story uses joined
    # eager loading of collections — confirm against the model definition.
    return list(result.unique().scalars().all())


async def count_story_versions(db: AsyncSession, story_id: str) -> int:
    """Return the number of versions recorded for a story (0 if none)."""
    stmt = select(func.count(StoryVersion.id)).where(StoryVersion.story_id == story_id)
    result = await db.execute(stmt)
    return result.scalar() or 0


async def create_story_image_intent(
    db: AsyncSession,
    *,
    story_id: str,
    story_version_id: str | None,
    caption: str,
    prompt_brief: str,
    style_profile: str | None = None,
    role: str = "primary",
) -> StoryImageIntent:
    """Create an image intent for a story in "pending" status. Caller must commit.

    Args:
        role: Intent role to store; defaults to "primary", matching the
            ``role`` parameter of the get/delete helpers below.
    """
    intent = StoryImageIntent(
        id=_new_id(),
        story_id=story_id,
        story_version_id=story_version_id,
        intent_role=role,
        caption=caption,
        prompt_brief=prompt_brief,
        style_profile=style_profile,
        status="pending",
    )
    db.add(intent)
    return intent


async def get_story_image_intent_by_story(
    db: AsyncSession, story_id: str, *, role: str = "primary"
) -> StoryImageIntent | None:
    """Get the image intent with the given role for a story, or None.

    Uses ``scalar_one_or_none``, so this raises ``MultipleResultsFound`` if
    more than one intent exists for the (story_id, role) pair.
    """
    stmt = (
        select(StoryImageIntent)
        .where(StoryImageIntent.story_id == story_id)
        .where(StoryImageIntent.intent_role == role)
    )
    result = await db.execute(stmt)
    return result.unique().scalar_one_or_none()


async def delete_story_image_intents_by_story(
    db: AsyncSession,
    story_id: str,
    *,
    role: str = "primary",
    statuses: list[str] | None = None,
) -> int:
    """Delete a story's image intents with the given role.

    Args:
        statuses: When None, delete every intent with this role. Otherwise
            delete only intents whose status is listed (e.g. clear only
            pending/failed so an in-flight "processing" intent is not
            interrupted). An empty list deletes nothing.

    Returns:
        The row count reported by the driver, or 0 when it reports none.
    """
    stmt = delete(StoryImageIntent).where(
        StoryImageIntent.story_id == story_id,
        StoryImageIntent.intent_role == role,
    )
    if statuses is not None:
        stmt = stmt.where(StoryImageIntent.status.in_(statuses))
    result = await db.execute(stmt)
    return result.rowcount or 0


async def get_stories_by_ids(db: AsyncSession, story_ids: list[str]) -> list[Story]:
    """Fetch stories by id, returned in the order of ``story_ids``.

    Ids with no matching story are silently skipped.
    """
    if not story_ids:
        return []
    stmt = select(Story).where(Story.id.in_(story_ids))
    result = await db.execute(stmt)
    stories = list(result.unique().scalars().all())
    position = {sid: i for i, sid in enumerate(story_ids)}
    # Every row matched `story_ids`, so the fallback is purely defensive; using
    # len(story_ids) (not a magic constant) keeps it after every known position.
    return sorted(stories, key=lambda s: position.get(s.id, len(story_ids)))


def _recent_stories_for_evidence_stmt(user_id: str, query: str | None, limit: int):
    """Build the SELECT shared by the async and sync evidence lookups.

    Selects the user's "active" stories, most recently updated first. When
    ``query`` is non-blank, it keeps only stories whose title or summary
    contains it (case-insensitive). LIKE wildcards in the query are escaped,
    so they match literally.
    """
    stmt = select(Story).where(Story.user_id == user_id, Story.status == "active")
    term = (query or "").strip()
    if term:
        pattern = f"%{_escape_like(term)}%"
        stmt = stmt.where(
            or_(
                Story.title.ilike(pattern, escape=_LIKE_ESCAPE),
                Story.summary.ilike(pattern, escape=_LIKE_ESCAPE),
            )
        )
    return stmt.order_by(Story.updated_at.desc()).limit(limit)


async def list_recent_stories_for_evidence(
    db: AsyncSession,
    user_id: str,
    *,
    query: str | None = None,
    limit: int = 5,
) -> list[Story]:
    """For memory retrieval: a user's active stories, optionally fuzzy-matched on title/summary."""
    result = await db.execute(_recent_stories_for_evidence_stmt(user_id, query, limit))
    return list(result.unique().scalars().all())


def list_recent_stories_for_evidence_sync(
    session: Session,
    user_id: str,
    *,
    query: str | None = None,
    limit: int = 5,
) -> list[Story]:
    """Sync-session version of `list_recent_stories_for_evidence` (Celery / retrieve_evidence_sync)."""
    result = session.execute(_recent_stories_for_evidence_stmt(user_id, query, limit))
    return list(result.unique().scalars().all())