"""Conversation repository — Conversation, turn log, and Segment data access."""

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.features.conversation.models import Conversation, ConversationMessage, Segment


async def get_conversation(
    conversation_id: str, db: AsyncSession
) -> Conversation | None:
    """Return the conversation with primary key ``conversation_id``, or ``None``.

    Note: unlike :func:`get_user_conversations`, this does not exclude
    soft-deleted conversations (``deleted_at`` set); callers that care must
    check ``deleted_at`` themselves.
    """
    return await db.get(Conversation, conversation_id)


async def get_user_conversations(user_id: str, db: AsyncSession) -> list[Conversation]:
    """Return a user's non-deleted conversations, most recently active first.

    Activity time is ``last_message_at``, falling back to ``started_at`` when
    ``last_message_at`` is NULL.
    """
    stmt = (
        select(Conversation)
        .where(
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
        .order_by(
            func.coalesce(Conversation.last_message_at, Conversation.started_at).desc()
        )
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())


def add_conversation(conv: Conversation, db: AsyncSession) -> None:
    """Stage ``conv`` for insertion in ``db`` (no flush or commit is performed)."""
    db.add(conv)


def add_conversation_message(msg: ConversationMessage, db: AsyncSession) -> None:
    """Stage ``msg`` for insertion in ``db`` (no flush or commit is performed)."""
    db.add(msg)


async def get_conversation_messages(
    conversation_id: str, db: AsyncSession
) -> list[ConversationMessage]:
    """Return all messages of a conversation, ordered by ``created_at`` ascending."""
    stmt = (
        select(ConversationMessage)
        .where(ConversationMessage.conversation_id == conversation_id)
        .order_by(ConversationMessage.created_at)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())


async def set_latest_ai_message_tts_audio_urls(
    conversation_id: str,
    db: AsyncSession,
    *,
    tts_audio_urls: list[str],
    segment_id: str | None = None,
) -> ConversationMessage | None:
    """Store TTS audio URLs on the newest ``"ai"`` message of a conversation.

    Args:
        conversation_id: Conversation whose messages are searched.
        db: Session used to load the message. The change is applied only to
            the loaded ORM object; this function does not flush or commit.
        tts_audio_urls: URLs to store. A copy is assigned, so later mutation
            of the caller's list does not affect the message.
        segment_id: When given, only messages with this ``segment_id`` are
            considered.

    Returns:
        The updated message, or ``None`` if no matching ``"ai"`` message exists.
    """
    stmt = select(ConversationMessage).where(
        ConversationMessage.conversation_id == conversation_id,
        ConversationMessage.role == "ai",
    )
    if segment_id is not None:
        stmt = stmt.where(ConversationMessage.segment_id == segment_id)
    # Only the newest row is used; LIMIT 1 avoids loading the whole history.
    stmt = stmt.order_by(ConversationMessage.created_at.desc()).limit(1)
    result = await db.execute(stmt)
    row = result.scalars().first()
    if row is None:
        return None
    row.tts_audio_urls = list(tts_audio_urls)
    return row


async def get_segments_for_conversation(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return all segments of a conversation, ordered by ``created_at`` ascending."""
    stmt = (
        select(Segment)
        .where(Segment.conversation_id == conversation_id)
        .order_by(Segment.created_at)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())


async def get_segments_for_organize(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return segments to organize, preserving the legacy semantics.

    Prefer the segments that have not yet completed Phase 1 (see
    :func:`get_segments_pending_phase1`); if there are none, fall back to all
    segments of the conversation.
    """
    pending = await get_segments_pending_phase1(conversation_id, db)
    if pending:
        return pending
    return await get_segments_for_conversation(conversation_id, db)


async def get_segments_pending_phase1(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return segments that have not yet been through Phase 1 classification.

    A segment qualifies when ``topic_category`` is NULL and it is neither
    marked ``narrated`` nor ``processed``. Results are ordered by
    ``created_at`` ascending.
    """
    stmt = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.topic_category.is_(None),
            Segment.narrated.is_(False),
            Segment.processed.is_(False),
        )
        .order_by(Segment.created_at)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())


async def conversation_has_pending_phase2(
    conversation_id: str, db: AsyncSession
) -> bool:
    """Return whether this conversation has segments awaiting narrative (Phase 2).

    A segment is pending when Phase 1 has assigned a ``topic_category``, it is
    not yet ``narrated``, and it is not flagged ``skip_narrative``.
    """
    stmt = (
        select(Segment.id)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.topic_category.isnot(None),
            Segment.narrated.is_(False),
            Segment.skip_narrative.is_(False),
        )
        # Existence probe: stop at the first match instead of counting all rows.
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.first() is not None


async def count_segments_for_user(user_id: str, db: AsyncSession) -> int:
    """Count segments across all of a user's non-deleted conversations."""
    stmt = (
        select(func.count(Segment.id))
        .select_from(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
    )
    result = await db.execute(stmt)
    return int(result.scalar() or 0)