- 访谈:新增 interview_state_hints,联动 orchestrator 与提示词 - 回忆录:story_pipeline_sync/state/memory/post_commit 与 Celery 任务调整 - 基建:开发用 celery broker、compose/development 脚本、依赖注入 - eval-web:移除数据集/实验/版本等页面与流式轮询,突出 Playground - 文档与单测同步
100 lines
3.5 KiB
Python
100 lines
3.5 KiB
Python
"""评测台:查询用户 segment 是否已完成回忆录 Phase1(topic_category 已写入)。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.features.conversation.models import Conversation, Segment
|
||
from app.features.conversation.ws.pipeline import background_runner
|
||
from app.features.evaluation.errors import (
|
||
EvaluationBadRequestError,
|
||
EvaluationNotFoundError,
|
||
)
|
||
from app.features.evaluation.schemas import MemoirPhase1ReadyOut, MemoirSubmitOut
|
||
|
||
|
||
class MemoirReadinessService:
    """Evaluation-console queries and actions around memoir Phase 1.

    A segment is treated as having finished Phase 1 once its ``topic_category``
    column is non-NULL; that is the only readiness criterion checked here.
    """

    # How many missing segment ids are echoed back in a "not in conversation" error.
    _MAX_MISSING_SHOWN = 5

    def __init__(self, db: AsyncSession) -> None:
        """Bind the service to the async SQLAlchemy session used for all queries."""
        self._db = db

    @staticmethod
    def _require_conversation_id(conversation_id: str | None) -> str:
        """Return the stripped conversation id.

        Raises:
            EvaluationBadRequestError: if the id is None or blank after stripping.
        """
        cid = (conversation_id or "").strip()
        if not cid:
            raise EvaluationBadRequestError("conversation_id is required")
        return cid

    @staticmethod
    def _normalize_segment_ids(segment_ids: list[str] | None) -> list[str]:
        """Strip ids, dropping None/blank entries and duplicates (first-seen order kept)."""
        stripped = ((s or "").strip() for s in segment_ids or ())
        # dict.fromkeys dedupes while preserving insertion order.
        return list(dict.fromkeys(s for s in stripped if s))

    @staticmethod
    def _classify_segments(
        ids: list[str], rows: list[Segment]
    ) -> tuple[list[str], list[str]]:
        """Split requested ``ids`` into ``(missing, pending)``, both in request order.

        ``missing`` holds ids with no matching row; ``pending`` holds found ids
        whose row still has ``topic_category`` set to None.
        """
        # Compare via str(): row ids may not be plain strings (the submit path
        # also stringifies Segment.id), while the requested ids always are.
        by_id = {str(row.id): row for row in rows}
        missing = [i for i in ids if i not in by_id]
        pending = [i for i in ids if i in by_id and by_id[i].topic_category is None]
        return missing, pending

    async def _load_active_conversation(self, cid: str) -> Conversation:
        """Fetch conversation ``cid`` by primary key.

        Raises:
            EvaluationNotFoundError: if no row exists or its ``deleted_at`` is set.
        """
        conv = await self._db.get(Conversation, cid)
        if conv is None or conv.deleted_at is not None:
            raise EvaluationNotFoundError("conversation not found")
        return conv

    async def memoir_phase1_ready_for_segments(
        self,
        *,
        conversation_id: str,
        segment_ids: list[str],
    ) -> MemoirPhase1ReadyOut:
        """Report whether every given segment of a conversation has finished Phase 1.

        Args:
            conversation_id: Conversation the segments must belong to.
            segment_ids: Segment ids to check; surrounding whitespace is stripped,
                and blank entries and duplicates are ignored.

        Returns:
            ``MemoirPhase1ReadyOut`` where ``ready`` is True iff no checked segment
            still has ``topic_category`` unset. ``checked_segment_ids`` is the
            normalized id list, and ``pending_segment_ids`` lists the unfinished
            ones. Both are in request order.

        Raises:
            EvaluationBadRequestError: blank conversation id, no usable segment
                ids, or some ids that do not belong to this conversation.
            EvaluationNotFoundError: the conversation is missing or has
                ``deleted_at`` set.
        """
        # Validation order matches the original: conversation id, segment ids, lookup.
        cid = self._require_conversation_id(conversation_id)
        ids = self._normalize_segment_ids(segment_ids)
        if not ids:
            raise EvaluationBadRequestError("segment_ids is required")
        await self._load_active_conversation(cid)

        stmt = select(Segment).where(
            Segment.id.in_(ids),
            Segment.conversation_id == cid,
        )
        result = await self._db.execute(stmt)
        missing, pending = self._classify_segments(ids, list(result.scalars().all()))
        if missing:
            shown = self._MAX_MISSING_SHOWN
            raise EvaluationBadRequestError(
                "segment not in conversation: "
                + ", ".join(missing[:shown])
                + ("…" if len(missing) > shown else "")
            )

        return MemoirPhase1ReadyOut(
            ready=not pending,
            checked_segment_ids=ids,
            pending_segment_ids=pending,
        )

    async def submit_memoir_phase1_for_conversation(
        self,
        *,
        conversation_id: str,
    ) -> MemoirSubmitOut:
        """Submit a conversation's still-pending Phase 1 segments as one batch.

        Intended to match the flush done when a conversation ends. It collects
        segments with ``processed`` False and ``topic_category`` unset, oldest
        ``created_at`` first, and hands them to ``background_runner.flush_pending``
        for the conversation's user.

        Args:
            conversation_id: Conversation whose pending segments are submitted.

        Returns:
            ``MemoirSubmitOut`` echoing the conversation id, the stringified user id
            and the submitted segment ids. ``celery_task_id`` is whatever
            ``flush_pending`` returned. It is None when there was nothing to submit,
            in which case ``flush_pending`` is not called.

        Raises:
            EvaluationBadRequestError: blank conversation id.
            EvaluationNotFoundError: the conversation is missing or has
                ``deleted_at`` set.
        """
        cid = self._require_conversation_id(conversation_id)
        conv = await self._load_active_conversation(cid)
        uid = str(conv.user_id)

        stmt = (
            select(Segment.id)
            .where(
                Segment.conversation_id == cid,
                Segment.processed.is_(False),
                Segment.topic_category.is_(None),
            )
            .order_by(Segment.created_at.asc())
        )
        result = await self._db.execute(stmt)
        segment_ids = [str(i) for i in result.scalars().all()]

        task_id = None
        if segment_ids:
            task_id = await background_runner.flush_pending(
                uid, extra_segment_ids=segment_ids
            )

        return MemoirSubmitOut(
            conversation_id=cid,
            user_id=uid,
            segment_ids=segment_ids,
            celery_task_id=task_id,
        )
|