173 lines
6.0 KiB
Python
173 lines
6.0 KiB
Python
"""内部评测:按 App 一致路径回放用户轮次(segment + orchestrator + memoir 队列)。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import secrets
|
||
import uuid
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.db import utc_now
|
||
from app.core.logging import get_logger
|
||
from app.core.security import hash_password
|
||
from app.features.auth import repo as auth_repo
|
||
from app.features.conversation.models import Conversation
|
||
from app.features.conversation.service import ConversationService
|
||
from app.features.conversation.ws.pipeline import (
|
||
background_runner,
|
||
process_user_message,
|
||
)
|
||
from app.features.evaluation.errors import (
|
||
EvaluationBadRequestError,
|
||
EvaluationNotFoundError,
|
||
)
|
||
from app.features.evaluation.user_export_fixtures import read_user_export_fixture
|
||
from app.features.quota.service import QuotaService
|
||
from app.features.user.models import User
|
||
|
||
# Module-level logger; call sites in this module pass "{}"-style placeholders
# with positional arguments.
logger = get_logger(__name__)
|
||
|
||
|
||
class ReplayConversationService:
    """Replay user turns for internal evaluation along the same path the App uses.

    Each replayed turn creates a user segment, queues it on the memoir
    ``background_runner``, and then runs ``process_user_message``.
    """

    # Maximum number of attempts to find an unused pseudo phone number.
    _PHONE_ALLOC_ATTEMPTS = 8

    def __init__(self, db: AsyncSession, quota_service: QuotaService) -> None:
        self._db = db
        self._quota = quota_service

    async def _allocate_eval_phone(self) -> str:
        """Return a random ``eval_``-prefixed phone not yet used by any user.

        Raises:
            EvaluationBadRequestError: if every attempt hit an existing user.
        """
        for _ in range(self._PHONE_ALLOC_ATTEMPTS):
            candidate = f"eval_{secrets.token_hex(10)}"
            if not await auth_repo.get_user_by_phone(candidate, self._db):
                return candidate
        raise EvaluationBadRequestError("could not allocate eval phone")

    async def _open_conversation(self, user_id: str) -> str:
        """Create a fresh conversation for ``user_id`` and return its id.

        Raises:
            EvaluationBadRequestError: if ``ensure_ws_connection`` reports an
                error or returns no conversation.
        """
        conversation_id = str(uuid.uuid4())
        conv_service = ConversationService(self._db, self._quota)
        conv, err = await conv_service.ensure_ws_connection(conversation_id, user_id)
        if err or not conv:
            raise EvaluationBadRequestError(err or "failed to create conversation")
        return conversation_id

    async def create_eval_sandbox(self) -> tuple[str, str, str, str]:
        """Create a throwaway evaluation-only user (unique pseudo phone) plus a new conversation.

        Returns:
            ``(user_id, conversation_id, phone, nickname)``.

        Raises:
            EvaluationBadRequestError: if no unused phone could be allocated or
                the conversation could not be created.
        """
        user_id = str(uuid.uuid4())
        phone = await self._allocate_eval_phone()

        user = User(
            id=user_id,
            phone=phone,
            # Random, never-disclosed password: the sandbox user is not meant
            # to log in interactively.
            password_hash=hash_password(secrets.token_urlsafe(24)),
            nickname="评测临时用户",
            subscription_type="free",
            created_at=utc_now(),
        )
        await auth_repo.create_user(user, self._db)
        await self._db.commit()
        await self._db.refresh(user)
        # NOTE(review): snapshot before creating the conversation. If that step
        # commits and the session expires instances on commit, reading
        # ``user.nickname`` afterwards would need an implicit lazy load, which
        # AsyncSession does not support.
        nickname = user.nickname

        conversation_id = await self._open_conversation(user_id)

        logger.info(
            "eval sandbox user_id={} phone={} conversation_id={}",
            user_id,
            phone,
            conversation_id,
        )
        return user_id, conversation_id, phone, nickname

    async def bootstrap_conversation(self, user_id: str) -> str:
        """Create a new conversation for an existing user and return its id.

        Raises:
            EvaluationBadRequestError: if ``user_id`` is blank, the user does
                not exist, or the conversation could not be created.
        """
        uid = (user_id or "").strip()
        if not uid:
            raise EvaluationBadRequestError("user_id is required")
        user = await self._db.get(User, uid)
        if not user:
            raise EvaluationBadRequestError("user not found")

        conversation_id = await self._open_conversation(uid)

        logger.info(
            "eval replay bootstrap conversation_id={} user_id={}",
            conversation_id,
            uid,
        )
        return conversation_id

    async def replay_fixture(
        self,
        *,
        conversation_id: str,
        fixture_filename: str,
        flush_memoir_after: bool,
        skip_tts: bool,
    ) -> tuple[int, list[str]]:
        """Replay the user utterances from an export fixture into a conversation.

        Returns:
            ``(turns_replayed, utterances)``, where ``utterances`` is the list of
            stripped, non-empty user texts taken from the fixture.

        Raises:
            EvaluationNotFoundError: if the fixture file does not exist or the
                conversation is missing/deleted.
            EvaluationBadRequestError: if the fixture is invalid or yields no
                user utterances (plus the errors of ``replay_utterances``).
        """
        try:
            turns, _ = read_user_export_fixture(fixture_filename)
        except ValueError as e:
            raise EvaluationBadRequestError(str(e)) from e
        except FileNotFoundError:
            raise EvaluationNotFoundError("fixture not found") from None

        # Only the user side of each turn is replayed; strip each text once and
        # drop empty/None entries.
        stripped = ((user_text or "").strip() for user_text, _ in turns)
        utterances = [text for text in stripped if text]
        if not utterances:
            raise EvaluationBadRequestError("fixture produced no user utterances")

        n = await self.replay_utterances(
            conversation_id=conversation_id,
            utterances=utterances,
            flush_memoir_after=flush_memoir_after,
            skip_tts=skip_tts,
        )
        return n, utterances

    async def replay_utterances(
        self,
        *,
        conversation_id: str,
        utterances: list[str],
        flush_memoir_after: bool,
        skip_tts: bool,
    ) -> int:
        """Feed each non-blank utterance through the App message path.

        For every utterance: create a user segment, queue it on the memoir
        ``background_runner``, then run ``process_user_message``. Blank entries
        are skipped. When ``flush_memoir_after`` is true, pending memoir work
        for the conversation owner is flushed at the end.

        Returns:
            Number of utterances actually replayed.

        Raises:
            EvaluationBadRequestError: if ``conversation_id`` is blank or the
                owning user is missing.
            EvaluationNotFoundError: if the conversation is missing or deleted.
        """
        cid = (conversation_id or "").strip()
        if not cid:
            raise EvaluationBadRequestError("conversation_id is required")
        conv = await self._db.get(Conversation, cid)
        if not conv or conv.deleted_at is not None:
            raise EvaluationNotFoundError("conversation not found")

        # NOTE(review): read the owner id once, before the pipeline runs.
        # process_user_message may commit; if instances are expired on commit,
        # re-reading ``conv.user_id`` later would require an implicit lazy load.
        owner_id = conv.user_id
        user = await self._db.get(User, owner_id)
        if not user:
            raise EvaluationBadRequestError("user not found for conversation")

        conv_service = ConversationService(self._db, self._quota)
        count = 0
        for raw in utterances:
            text = (raw or "").strip()
            if not text:
                continue
            segment = await conv_service.create_user_segment(conv, owner_id, text)
            ts = segment.created_at or conv.last_message_at
            # Queue for memoir processing before running the reply pipeline,
            # mirroring the order used on the live path here.
            await background_runner.queue_message(
                owner_id,
                segment.id,
                text_char_count=len(text),
            )
            await process_user_message(
                conversation_id=cid,
                user_message=text,
                conversation=conv,
                segment=segment,
                db=self._db,
                user=user,
                user_message_timestamp=ts,
                force_skip_tts=skip_tts,
            )
            count += 1

        if flush_memoir_after and owner_id:
            await background_runner.flush_pending(owner_id)

        logger.info(
            "eval replay done conversation_id={} turns={} flush={} skip_tts={}",
            cid,
            count,
            flush_memoir_after,
            skip_tts,
        )
        return count
|