62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
|
|
import asyncio
|
||
|
|
import sys
|
||
|
|
from types import SimpleNamespace
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from app.adapters.asr.whisper_local import (
|
||
|
|
WhisperASRProvider,
|
||
|
|
_looks_like_subtitle_hallucination,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def test_subtitle_watermark_detection() -> None:
    """Subtitle-credit watermarks are flagged; ordinary speech is not.

    Each case checks the helper's result by identity, so it must return
    exactly ``True`` or ``False``.
    """
    cases = (
        ("字幕by索兰娅", True),
        ("今天想聊聊童年往事", False),
    )
    for sample, expected in cases:
        assert _looks_like_subtitle_hallucination(sample) is expected
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_transcribe_retries_decode_audio_after_discarded_pass2(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A discarded second pass triggers a third pass on pre-decoded audio.

    The fake model is scripted per ``transcribe`` call:

    1. returns no segments;
    2. returns only a subtitle-credit watermark, which should be discarded;
    3. must receive the value produced by the stubbed
       ``faster_whisper.decode_audio`` and returns real speech.

    The provider is expected to return the third pass's text after exactly
    three model calls.
    """

    class DummyModel:
        """Stand-in for a whisper model that replays scripted results."""

        def __init__(self) -> None:
            # Audio argument of every transcribe() call, in call order.
            self.calls: list[object] = []

        def transcribe(self, audio: object, **_: object):
            self.calls.append(audio)
            n = len(self.calls)
            if n == 1:
                # Pass 1: nothing recognised.
                return iter([]), SimpleNamespace()
            if n == 2:
                # Pass 2: only a subtitle watermark, which should be discarded.
                return iter([SimpleNamespace(text="字幕by索兰娅")]), SimpleNamespace()
            if n == 3:
                # Pass 3 must run on audio produced by the stubbed decode_audio.
                assert audio == "decoded-audio"
                return (
                    iter([SimpleNamespace(text="你好,今天想聊聊童年。")]),
                    SimpleNamespace(),
                )
            raise AssertionError(f"unexpected transcribe call #{n}")

    async def fake_to_thread(fn, /, *args, **kwargs):
        # Mirror asyncio.to_thread(func, /, *args, **kwargs) so the test works
        # whether the provider wraps its call in a closure or forwards
        # arguments; run inline to keep the test deterministic.
        return fn(*args, **kwargs)

    def fake_decode_audio(
        _: object, sampling_rate: int = 16000, **__: object
    ) -> str:
        # Any other keyword arguments the provider passes (e.g. split_stereo)
        # are accepted and ignored.
        assert sampling_rate == 16000
        return "decoded-audio"

    monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)
    # NOTE(review): this stub only takes effect if whisper_local imports
    # faster_whisper lazily (at call time); a module-level import there would
    # bypass it — TODO confirm.
    monkeypatch.setitem(
        sys.modules,
        "faster_whisper",
        SimpleNamespace(decode_audio=fake_decode_audio),
    )

    provider = WhisperASRProvider()
    model = DummyModel()
    provider._model = model

    text = await provider.transcribe(b"fake-audio", format="m4a")

    assert text == "你好,今天想聊聊童年。"
    assert len(model.calls) == 3
|