88 lines
2.7 KiB
Python
88 lines
2.7 KiB
Python
|
|
"""HTTP client for pending-confirmation and resolve endpoints."""
|
||
|
|
|
||
|
|
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import quote, urljoin

import httpx
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class PendingConfirmationPayload:
    """Typed view of a pending-confirmation response body.

    Instances are built by ``ConfirmationApiClient.parse_pending`` from the
    decoded JSON returned by the pending-confirmation endpoint.

    ``prompt_audio_mp3_base64`` and ``raw`` are excluded from ``repr`` so that
    logging or printing a payload does not dump the encoded audio clip (which
    ``raw`` also contains, since it is the unmodified response body).
    """

    surgery_id: str
    # Required identifier; used to address the resolve endpoint.
    confirmation_id: str
    # Text of the prompt shown/spoken to the user.
    prompt_text: str
    # Base64-encoded MP3 of the prompt; presumably a full audio clip and
    # therefore potentially large, hence hidden from repr.
    prompt_audio_mp3_base64: str = field(repr=False)
    # Answer options as sent by the server; the per-option dict schema is
    # defined server-side (not visible here — confirm against the API).
    options: list[dict[str, Any]]
    # Model's top-1 prediction and its confidence score.
    model_top1_label: str
    model_top1_confidence: float
    # Creation timestamp exactly as the server formatted it (kept as a string).
    created_at: str
    # The unmodified decoded response body, including any fields not mapped
    # above (and the audio payload), hence hidden from repr.
    raw: dict[str, Any] = field(repr=False)
|
||
|
|
|
||
|
|
|
||
|
|
class ConfirmationApiClient:
    """Synchronous HTTP client for the pending-confirmation and resolve endpoints.

    Holds one ``httpx.Client`` for its lifetime. Release it with :meth:`close`,
    or use the client as a context manager::

        with ConfirmationApiClient("https://host/api") as api:
            status, body = api.get_pending("surgery-1")

    Request methods return ``(status_code, body)`` instead of raising on
    non-2xx statuses; transport errors raised by ``httpx`` propagate.
    """

    def __init__(self, base_url: str, timeout: float = 60.0) -> None:
        """Create a client rooted at *base_url*.

        Args:
            base_url: API root. It is normalized to end in exactly one ``/``
                so relative paths are joined beneath it instead of replacing
                its last segment.
            timeout: Timeout (seconds) passed to ``httpx.Client``.
        """
        self._base = base_url.rstrip("/") + "/"
        self._timeout = timeout
        self._client = httpx.Client(timeout=timeout)

    def __enter__(self) -> ConfirmationApiClient:
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Always release the pooled connections; exceptions are not suppressed.
        self.close()

    @property
    def base_url_normalized(self) -> str:
        """Base URL as used for joining (always ends in a single ``/``)."""
        return self._base

    def close(self) -> None:
        """Close the underlying ``httpx.Client``."""
        self._client.close()

    def _url(self, path: str) -> str:
        """Resolve *path* relative to the base URL; leading ``/`` are ignored."""
        return urljoin(self._base, path.lstrip("/"))

    def _pending_url(self, surgery_id: str) -> str:
        """URL of the pending-confirmation resource for *surgery_id*."""
        # Encode every reserved char (including "/") so the id is always
        # exactly one path segment.
        sid = quote(str(surgery_id), safe="")
        return self._url(f"client/surgeries/{sid}/pending-confirmation")

    def _resolve_url(self, surgery_id: str, confirmation_id: str) -> str:
        """URL of the resolve action for one pending confirmation."""
        sid = quote(str(surgery_id), safe="")
        cid = quote(str(confirmation_id), safe="")
        return self._url(
            f"client/surgeries/{sid}/pending-confirmation/{cid}/resolve"
        )

    @staticmethod
    def _decode_response(r: httpx.Response) -> tuple[int, dict[str, Any] | str]:
        """Return ``(status_code, body)`` for a response.

        *body* is ``{}`` for an empty response, the decoded JSON when the text
        parses, and otherwise the raw text. Note that a JSON array or scalar
        is returned as decoded, despite the ``dict`` in the annotation.
        """
        text = r.text
        if not text:
            return r.status_code, {}
        try:
            body: dict[str, Any] | str = json.loads(text)
        except json.JSONDecodeError:
            body = text
        return r.status_code, body

    def get_pending(self, surgery_id: str) -> tuple[int, dict[str, Any] | str]:
        """GET the pending confirmation for *surgery_id*.

        Returns:
            ``(status_code, body)``; see :meth:`_decode_response` for *body*.
        """
        r = self._client.get(self._pending_url(surgery_id))
        return self._decode_response(r)

    def parse_pending(self, body: dict[str, Any]) -> PendingConfirmationPayload:
        """Convert a decoded pending-confirmation body into a payload object.

        Missing optional fields, and fields that are JSON ``null``, default to
        ``""`` (strings), ``[]`` (options) or ``0.0`` (confidence).

        Raises:
            KeyError: if ``confirmation_id`` is absent.
            ValueError: if ``model_top1_confidence`` is not numeric.
        """

        def text(key: str) -> str:
            # Treat null like a missing key instead of producing "None".
            value = body.get(key)
            return "" if value is None else str(value)

        confidence = body.get("model_top1_confidence")
        return PendingConfirmationPayload(
            surgery_id=text("surgery_id"),
            confirmation_id=str(body["confirmation_id"]),
            prompt_text=text("prompt_text"),
            prompt_audio_mp3_base64=text("prompt_audio_mp3_base64"),
            options=list(body.get("options") or []),
            model_top1_label=text("model_top1_label"),
            model_top1_confidence=0.0 if confidence is None else float(confidence),
            created_at=text("created_at"),
            raw=body,
        )

    def post_resolve(
        self,
        surgery_id: str,
        confirmation_id: str,
        wav_bytes: bytes,
        filename: str = "voice.wav",
    ) -> tuple[int, dict[str, Any] | str]:
        """Upload a WAV answer to resolve a pending confirmation.

        The audio is sent as the multipart form field ``audio`` with content
        type ``audio/wav``.

        Returns:
            ``(status_code, body)``; see :meth:`_decode_response` for *body*.
        """
        files = {"audio": (filename, wav_bytes, "audio/wav")}
        r = self._client.post(
            self._resolve_url(surgery_id, confirmation_id), files=files
        )
        return self._decode_response(r)
|