Files
life-echo/api/app/features/evaluation/memoir_readiness_service.py
Kevin 064ad2161d refactor(eval+memoir):精简内部评测路由与服务,composite/对话摘要与 judge 能力补强
- 访谈:新增 interview_state_hints,联动 orchestrator 与提示词
- 回忆录:story_pipeline_sync/state/memory/post_commit 与 Celery 任务调整
- 基建:开发用 celery broker、compose/development 脚本、依赖注入
- eval-web:移除数据集/实验/版本等页面与流式轮询,突出 Playground
- 文档与单测同步
2026-04-08 21:36:12 +08:00

100 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""评测台:查询用户 segment 是否已完成回忆录 Phase1topic_category 已写入)。"""
from __future__ import annotations
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.pipeline import background_runner
from app.features.evaluation.errors import (
EvaluationBadRequestError,
EvaluationNotFoundError,
)
from app.features.evaluation.schemas import MemoirPhase1ReadyOut, MemoirSubmitOut
class MemoirReadinessService:
    """Evaluation-console helper around memoir Phase 1 state of segments.

    Offers two operations on a single conversation:

    * report which of a given set of segments still have no
      ``topic_category`` (treated here as "Phase 1 not finished");
    * collect the conversation's unprocessed, uncategorised segments and
      hand them to ``background_runner.flush_pending``.
    """

    def __init__(self, db: AsyncSession) -> None:
        """Store the async DB session used for every query of this service."""
        self._db = db

    @staticmethod
    def _format_missing_ids(missing: list[str], limit: int = 5) -> str:
        """Join at most ``limit`` ids with ``", "`` for an error message.

        A trailing ``", ..."`` is appended when ids were left out, so the
        reader can tell the list is truncated.
        """
        shown = ", ".join(missing[:limit])
        if len(missing) > limit:
            shown += ", ..."
        return shown

    async def memoir_phase1_ready_for_segments(
        self,
        *,
        conversation_id: str,
        segment_ids: list[str],
    ) -> MemoirPhase1ReadyOut:
        """Report whether the given segments all have ``topic_category`` set.

        Args:
            conversation_id: Id of the conversation the segments must belong
                to; surrounding whitespace is stripped.
            segment_ids: Segment ids to check; blank/``None`` entries are
                dropped and the rest are stripped.

        Returns:
            ``MemoirPhase1ReadyOut`` with ``ready`` true when no checked
            segment has ``topic_category`` equal to ``None``, the cleaned
            list of checked ids, and the ids still pending.

        Raises:
            EvaluationBadRequestError: ``conversation_id`` is blank, no usable
                segment id was supplied, or some segment id does not belong
                to the conversation.
            EvaluationNotFoundError: the conversation does not exist or has a
                non-null ``deleted_at``.
        """
        cid = (conversation_id or "").strip()
        if not cid:
            raise EvaluationBadRequestError("conversation_id is required")

        ids = [s.strip() for s in segment_ids if (s or "").strip()]
        if not ids:
            raise EvaluationBadRequestError("segment_ids is required")

        conv = await self._db.get(Conversation, cid)
        if not conv or conv.deleted_at is not None:
            raise EvaluationNotFoundError("conversation not found")

        # Restricting by conversation_id means ids from other conversations
        # simply do not come back and are reported as missing below.
        stmt = select(Segment).where(
            Segment.id.in_(ids),
            Segment.conversation_id == cid,
        )
        result = await self._db.execute(stmt)
        rows = list(result.scalars().all())

        found_ids = {s.id for s in rows}
        missing = [i for i in ids if i not in found_ids]
        if missing:
            raise EvaluationBadRequestError(
                "segment not in conversation: "
                + self._format_missing_ids(missing)
            )

        pending = [s.id for s in rows if s.topic_category is None]
        return MemoirPhase1ReadyOut(
            ready=not pending,
            checked_segment_ids=ids,
            pending_segment_ids=pending,
        )

    async def submit_memoir_phase1_for_conversation(
        self,
        *,
        conversation_id: str,
    ) -> MemoirSubmitOut:
        """Submit the conversation's still-pending segments in one batch.

        Selects, oldest first, the ids of segments in the conversation that
        are not ``processed`` and have no ``topic_category``, then passes
        them to ``background_runner.flush_pending`` as ``extra_segment_ids``.
        (The original note says this mirrors the end-of-conversation flush
        and goes through Celery; that is not visible from this class.)

        Args:
            conversation_id: Id of the conversation; surrounding whitespace
                is stripped.

        Returns:
            ``MemoirSubmitOut`` carrying the conversation id, the owner's
            user id as a string, the submitted segment ids, and whatever
            ``flush_pending`` returned as ``celery_task_id``. When nothing is
            pending, the id list is empty, ``celery_task_id`` is ``None`` and
            ``flush_pending`` is not called.

        Raises:
            EvaluationBadRequestError: ``conversation_id`` is blank.
            EvaluationNotFoundError: the conversation does not exist or has a
                non-null ``deleted_at``.
        """
        cid = (conversation_id or "").strip()
        if not cid:
            raise EvaluationBadRequestError("conversation_id is required")

        conv = await self._db.get(Conversation, cid)
        if not conv or conv.deleted_at is not None:
            raise EvaluationNotFoundError("conversation not found")

        uid = str(conv.user_id)
        stmt = (
            select(Segment.id)
            .where(
                Segment.conversation_id == cid,
                Segment.processed.is_(False),
                Segment.topic_category.is_(None),
            )
            .order_by(Segment.created_at.asc())
        )
        result = await self._db.execute(stmt)
        segment_ids = [str(i) for i in result.scalars().all()]

        if not segment_ids:
            # Nothing to submit: report an empty batch without touching the runner.
            return MemoirSubmitOut(
                conversation_id=cid,
                user_id=uid,
                segment_ids=[],
                celery_task_id=None,
            )

        task_id = await background_runner.flush_pending(
            uid, extra_segment_ids=segment_ids
        )
        return MemoirSubmitOut(
            conversation_id=cid,
            user_id=uid,
            segment_ids=segment_ids,
            celery_task_id=task_id,
        )