"""评测台:查询用户 segment 是否已完成回忆录 Phase1(topic_category 已写入)。""" from __future__ import annotations from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.features.conversation.models import Conversation, Segment from app.features.conversation.ws.pipeline import background_runner from app.features.evaluation.errors import ( EvaluationBadRequestError, EvaluationNotFoundError, ) from app.features.evaluation.schemas import MemoirPhase1ReadyOut, MemoirSubmitOut class MemoirReadinessService: def __init__(self, db: AsyncSession) -> None: self._db = db async def memoir_phase1_ready_for_segments( self, *, conversation_id: str, segment_ids: list[str], ) -> MemoirPhase1ReadyOut: cid = (conversation_id or "").strip() if not cid: raise EvaluationBadRequestError("conversation_id is required") ids = [s.strip() for s in segment_ids if (s or "").strip()] if not ids: raise EvaluationBadRequestError("segment_ids is required") conv = await self._db.get(Conversation, cid) if not conv or conv.deleted_at is not None: raise EvaluationNotFoundError("conversation not found") stmt = select(Segment).where( Segment.id.in_(ids), Segment.conversation_id == cid, ) result = await self._db.execute(stmt) rows = list(result.scalars().all()) found_ids = {s.id for s in rows} missing = [i for i in ids if i not in found_ids] if missing: raise EvaluationBadRequestError( "segment not in conversation: " + ", ".join(missing[:5]) + ("…" if len(missing) > 5 else "") ) pending = [s.id for s in rows if s.topic_category is None] ready = len(pending) == 0 return MemoirPhase1ReadyOut( ready=ready, checked_segment_ids=ids, pending_segment_ids=pending, ) async def submit_memoir_phase1_for_conversation( self, *, conversation_id: str, ) -> MemoirSubmitOut: """本会话内尚待 Phase1 的 segment 合并提交 Celery(与对话结束 flush 语义对齐)。""" cid = (conversation_id or "").strip() if not cid: raise EvaluationBadRequestError("conversation_id is required") conv = await self._db.get(Conversation, cid) if not conv or conv.deleted_at is not None: raise EvaluationNotFoundError("conversation not found") uid = str(conv.user_id) stmt = ( select(Segment.id) .where( Segment.conversation_id == cid, Segment.processed.is_(False), Segment.topic_category.is_(None), ) .order_by(Segment.created_at.asc()) ) result = await self._db.execute(stmt) segment_ids = [str(i) for i in result.scalars().all()] if not segment_ids: return MemoirSubmitOut( conversation_id=cid, user_id=uid, segment_ids=[], celery_task_id=None, ) task_id = await background_runner.flush_pending( uid, extra_segment_ids=segment_ids ) return MemoirSubmitOut( conversation_id=cid, user_id=uid, segment_ids=segment_ids, celery_task_id=task_id, )