"""只读访问生产对话表,供内部浏览与快照。""" from __future__ import annotations from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from app.features.conversation.models import Conversation, ConversationMessage, Segment class SessionCatalogRepo: def __init__(self, db: AsyncSession) -> None: self._db = db async def count_conversations(self, *, status: str | None = None) -> int: q = ( select(func.count()) .select_from(Conversation) .where(Conversation.deleted_at.is_(None)) ) if status: q = q.where(Conversation.status == status) r = await self._db.execute(q) return int(r.scalar() or 0) async def list_conversations( self, *, offset: int = 0, limit: int = 50, user_id: str | None = None, q_text: str | None = None, status: str | None = None, ) -> list[Conversation]: stmt = select(Conversation).where(Conversation.deleted_at.is_(None)) if user_id: stmt = stmt.where(Conversation.user_id == user_id) if status: stmt = stmt.where(Conversation.status == status) if status == "active": stmt = stmt.order_by( Conversation.last_message_at.desc().nullslast(), Conversation.started_at.desc().nullslast(), ) else: stmt = stmt.order_by(Conversation.started_at.desc().nullslast()) stmt = stmt.offset(offset).limit(limit) # q_text: 简单按 topic 搜索(后续可扩展全文) if q_text: like = f"%{q_text.strip()}%" stmt = stmt.where( (Conversation.current_topic.isnot(None)) & (Conversation.current_topic.ilike(like)) ) stmt = stmt.options(joinedload(Conversation.user)) res = await self._db.execute(stmt) return list(res.scalars().unique().all()) async def get_conversation(self, conversation_id: str) -> Conversation | None: return await self._db.get(Conversation, conversation_id) async def list_segments_for_conversation( self, conversation_id: str ) -> list[Segment]: stmt = ( select(Segment) .where(Segment.conversation_id == conversation_id) .order_by(Segment.created_at.asc()) ) res = await self._db.execute(stmt) return list(res.scalars().all()) async def list_messages_for_conversation( self, conversation_id: str ) -> list[ConversationMessage]: stmt = ( select(ConversationMessage) .where(ConversationMessage.conversation_id == conversation_id) .order_by(ConversationMessage.created_at.asc()) ) res = await self._db.execute(stmt) return list(res.scalars().all())