Files
life-echo/api/app/features/evaluation/memoir_readiness_service.py

146 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""评测台:查询用户 segment 是否已完成回忆录 Phase1(topic_category 已写入)。"""
from __future__ import annotations
import time
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.db import utc_now
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.pipeline import background_runner
from app.features.evaluation.errors import (
EvaluationBadRequestError,
EvaluationNotFoundError,
)
from app.features.evaluation.phase1_job_timing import (
load_phase1_job_meta,
record_phase1_job_submitted,
)
from app.features.evaluation.schemas import MemoirPhase1ReadyOut, MemoirSubmitOut
class MemoirReadinessService:
    """Evaluation-console helpers around the memoir Phase1 pipeline.

    A segment is considered Phase1-complete once its ``topic_category``
    column is non-NULL.
    """

    def __init__(self, db: AsyncSession) -> None:
        self._db = db

    @staticmethod
    def _require_conversation_id(conversation_id: str) -> str:
        """Return the stripped conversation id.

        Raises:
            EvaluationBadRequestError: the id is None, empty or blank.
        """
        cid = (conversation_id or "").strip()
        if not cid:
            raise EvaluationBadRequestError("conversation_id is required")
        return cid

    async def _get_live_conversation(self, cid: str) -> Conversation:
        """Load the conversation ``cid``.

        Raises:
            EvaluationNotFoundError: the row is missing or has ``deleted_at`` set.
        """
        conv = await self._db.get(Conversation, cid)
        if not conv or conv.deleted_at is not None:
            raise EvaluationNotFoundError("conversation not found")
        return conv

    @staticmethod
    def _missing_segments_message(missing: list[str], limit: int = 5) -> str:
        """Build the error text for ids not found in the conversation.

        At most ``limit`` ids are listed; "…" is appended when some were cut.
        """
        shown = ", ".join(missing[:limit])
        suffix = "…" if len(missing) > limit else ""
        return "segment not in conversation: " + shown + suffix

    @staticmethod
    def _parse_utc_timestamp(raw: object) -> datetime | None:
        """Parse an ISO-8601 string into a timezone-aware datetime.

        A trailing ``Z`` is accepted as UTC. Naive values are assumed to be
        UTC. Returns None for non-strings, blank strings or unparseable input.
        """
        if not isinstance(raw, str):
            return None
        text = raw.strip()
        if not text:
            return None
        # datetime.fromisoformat only understands "Z" from Python 3.11 on.
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    async def memoir_phase1_ready_for_segments(
        self,
        *,
        conversation_id: str,
        segment_ids: list[str],
    ) -> MemoirPhase1ReadyOut:
        """Report whether Phase1 has finished for the given segments.

        Args:
            conversation_id: owning conversation; surrounding whitespace is
                ignored.
            segment_ids: segments to check. Entries are stripped, blank ones
                are dropped, and duplicates are collapsed (first occurrence
                wins, order preserved).

        Returns:
            MemoirPhase1ReadyOut with ``pending_segment_ids`` in request
            order. If stored job metadata carries a parseable
            ``submitted_at_utc``, the result also includes the submission
            time and the milliseconds elapsed since then.

        Raises:
            EvaluationBadRequestError: blank conversation id, no usable
                segment ids, or ids that do not belong to the conversation.
            EvaluationNotFoundError: conversation missing or ``deleted_at`` set.
        """
        cid = self._require_conversation_id(conversation_id)
        ids = list(
            dict.fromkeys(
                s.strip() for s in (segment_ids or ()) if (s or "").strip()
            )
        )
        if not ids:
            raise EvaluationBadRequestError("segment_ids is required")
        await self._get_live_conversation(cid)

        stmt = select(Segment).where(
            Segment.id.in_(ids),
            Segment.conversation_id == cid,
        )
        result = await self._db.execute(stmt)
        # str() keeps the lookup consistent with the request ids even if the
        # id column is not a plain string type.
        by_id = {str(seg.id): seg for seg in result.scalars().all()}

        missing = [i for i in ids if i not in by_id]
        if missing:
            raise EvaluationBadRequestError(self._missing_segments_message(missing))

        pending = [i for i in ids if by_id[i].topic_category is None]

        job_submitted_at_utc: datetime | None = None
        elapsed_ms_since_submit: int | None = None
        durations_ms: dict[str, int] = {}
        meta = await load_phase1_job_meta(cid)
        if meta:
            job_submitted_at_utc = self._parse_utc_timestamp(
                meta.get("submitted_at_utc")
            )
        if job_submitted_at_utc is not None:
            now = utc_now()
            # NOTE(review): utc_now() may return a naive datetime; anchor it to
            # UTC so subtracting the aware timestamp cannot raise TypeError.
            if now.tzinfo is None:
                now = now.replace(tzinfo=timezone.utc)
            elapsed_ms_since_submit = max(
                0,
                int((now - job_submitted_at_utc).total_seconds() * 1000),
            )
            durations_ms["since_playground_submit"] = elapsed_ms_since_submit

        return MemoirPhase1ReadyOut(
            ready=not pending,
            checked_segment_ids=ids,
            pending_segment_ids=pending,
            job_submitted_at_utc=job_submitted_at_utc,
            elapsed_ms_since_submit=elapsed_ms_since_submit,
            durations_ms=durations_ms,
        )

    async def submit_memoir_phase1_for_conversation(
        self,
        *,
        conversation_id: str,
    ) -> MemoirSubmitOut:
        """Submit this conversation's segments still awaiting Phase1 to Celery in one batch.

        This mirrors the flush performed when a conversation ends. A segment
        is selected when it has ``processed`` false and ``topic_category``
        NULL; segments are taken oldest first.

        Returns:
            MemoirSubmitOut. When no segment is pending, the task id,
            timestamp and elapsed fields are None. Otherwise ``elapsed_ms`` is
            the wall time spent in ``background_runner.flush_pending``.

        Raises:
            EvaluationBadRequestError: blank conversation id.
            EvaluationNotFoundError: conversation missing or ``deleted_at`` set.
        """
        cid = self._require_conversation_id(conversation_id)
        conv = await self._get_live_conversation(cid)
        uid = str(conv.user_id)

        stmt = (
            select(Segment.id)
            .where(
                Segment.conversation_id == cid,
                Segment.processed.is_(False),
                Segment.topic_category.is_(None),
            )
            .order_by(Segment.created_at.asc())
        )
        result = await self._db.execute(stmt)
        segment_ids = [str(i) for i in result.scalars().all()]
        if not segment_ids:
            return MemoirSubmitOut(
                conversation_id=cid,
                user_id=uid,
                segment_ids=[],
                celery_task_id=None,
                submitted_at_utc=None,
                elapsed_ms=None,
            )

        t0 = time.perf_counter()
        task_id = await background_runner.flush_pending(
            uid, extra_segment_ids=segment_ids
        )
        elapsed_ms = max(0, int((time.perf_counter() - t0) * 1000))
        submitted_at = await record_phase1_job_submitted(
            cid,
            celery_task_id=task_id,
            segment_count=len(segment_ids),
        )
        return MemoirSubmitOut(
            conversation_id=cid,
            user_id=uid,
            segment_ids=segment_ids,
            celery_task_id=task_id,
            submitted_at_utc=submitted_at,
            elapsed_ms=elapsed_ms,
        )