46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
"""实验进度 SSE(轮询 DB,轻量实现)。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
|
||
from fastapi import APIRouter, Header, Query
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
from app.core.db import AsyncSessionLocal
|
||
from app.features.evaluation.admin_service import EvaluationAdminService
|
||
from app.features.evaluation.internal_auth import verify_internal_eval_key
|
||
|
||
# Router for the internal evaluation SSE endpoint below. No prefix is set here;
# presumably it is included into the app / a parent router elsewhere — confirm.
router = APIRouter(tags=["internal-evaluation-stream"])
|
||
|
||
|
||
@router.get("/experiments/{experiment_id}/stream")
async def experiment_event_stream(
    experiment_id: str,
    key: str | None = Query(
        default=None,
        description="等同 X-Internal-Eval-Key,供 EventSource 使用",
    ),
    x_internal_eval_key: str | None = Header(default=None, alias="X-Internal-Eval-Key"),
):
    """Stream experiment progress as Server-Sent Events by polling the DB.

    The internal eval key is verified once, before streaming starts. It may be
    supplied via the ``X-Internal-Eval-Key`` header or the ``key`` query
    parameter; the query form exists for browser ``EventSource`` clients.

    About once per second a fresh snapshot from
    ``EvaluationAdminService.experiment_stream_snapshot`` is emitted as a single
    ``data:`` event. The stream ends:

    * after a ``{"error": "not_found"}`` event, when the snapshot is ``None``;
    * after the first snapshot whose ``status`` is ``"completed"`` or ``"failed"``.
    """
    verify_internal_eval_key(
        header_value=x_internal_eval_key,
        query_value=key,
    )

    async def _fetch_snapshot():
        # One short-lived session per poll, closed *before* the caller yields,
        # so a slow or disconnected client never pins a pooled DB connection.
        async with AsyncSessionLocal() as session:
            svc = EvaluationAdminService(session)
            return await svc.experiment_stream_snapshot(experiment_id)

    async def event_gen():
        while True:
            payload = await _fetch_snapshot()
            if payload is None:
                yield f"data: {json.dumps({'error': 'not_found'})}\n\n"
                break
            # default=str stringifies any value json cannot encode natively.
            yield f"data: {json.dumps(payload, default=str)}\n\n"
            if payload.get("status") in ("completed", "failed"):
                break
            await asyncio.sleep(1.0)

    return StreamingResponse(
        event_gen(),
        media_type="text/event-stream",
        # Ask browsers/proxies not to cache, and nginx-style proxies not to
        # buffer, so each event is delivered as soon as it is produced.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|