"""Conversation repository — Conversation, Segment data access."""

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.features.conversation.models import Conversation, Segment


async def get_conversation(conversation_id: str, db: AsyncSession) -> Conversation | None:
    """Load a single conversation by primary key.

    Returns ``None`` when no conversation with ``conversation_id`` exists.
    """
    return await db.get(Conversation, conversation_id)


async def get_user_conversations(user_id: str, db: AsyncSession) -> list[Conversation]:
    """Return every conversation owned by ``user_id``, most recently active first.

    "Activity" is ``last_message_at`` when it is set, otherwise ``started_at``.
    ``Conversation.id`` is used as a secondary sort key so conversations with
    identical activity timestamps are always returned in the same order.
    """
    stmt = (
        select(Conversation)
        .where(Conversation.user_id == user_id)
        .order_by(
            func.coalesce(Conversation.last_message_at, Conversation.started_at).desc(),
            # Tie-breaker: without it, equal timestamps yield an arbitrary order.
            Conversation.id,
        )
    )
    return list(await db.scalars(stmt))


def add_conversation(conv: Conversation, db: AsyncSession) -> None:
    """Add ``conv`` to the session.

    No flush or commit is performed here; the caller controls the transaction.
    """
    db.add(conv)


def _segments_in_order(conversation_id: str):
    """Build a SELECT of one conversation's segments in creation order."""
    return (
        select(Segment)
        .where(Segment.conversation_id == conversation_id)
        # Segment.id breaks ties between segments sharing a created_at value
        # (NOTE(review): e.g. rows stamped with a transaction-time default —
        # confirm against the model definition).
        .order_by(Segment.created_at, Segment.id)
    )


async def get_segments_for_conversation(conversation_id: str, db: AsyncSession) -> list[Segment]:
    """Return all segments of ``conversation_id``, oldest first."""
    return list(await db.scalars(_segments_in_order(conversation_id)))


async def get_segments_for_organize(conversation_id: str, db: AsyncSession) -> list[Segment]:
    """Return the segments an "organize" pass should work on, oldest first.

    Returns only the segments whose ``processed`` flag is ``False``. If there
    are none, falls back to returning every segment of the conversation.
    """
    # .is_(False) compiles to "IS false", so segments whose flag is NULL are
    # not treated as unprocessed.
    unprocessed_stmt = _segments_in_order(conversation_id).where(Segment.processed.is_(False))
    segments = list(await db.scalars(unprocessed_stmt))
    if segments:
        return segments
    return await get_segments_for_conversation(conversation_id, db)


async def count_segments_for_user(user_id: str, db: AsyncSession) -> int:
    """Count all segments across every conversation owned by ``user_id``.

    Returns 0 when the user has no conversations or no segments.
    """
    stmt = (
        select(func.count(Segment.id))
        .select_from(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(Conversation.user_id == user_id)
    )
    # `or 0` guards against a None scalar, matching the original contract.
    return await db.scalar(stmt) or 0