mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-authored-by: Clément Drouin <clement.drouin@mistral.ai> Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai> Co-authored-by: Gauthier Guinet <43207538+Gguinet@users.noreply.github.com> Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai> Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com> Co-authored-by: Quentin <torroba.q@gmail.com> Co-authored-by: Simon <80467011+sorgfresser@users.noreply.github.com> Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai> Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com> Co-authored-by: angelapopopo <angele.lenglemetz@mistral.ai> Co-authored-by: Mistral Vibe <vibe@mistral.ai>
166 lines
4.8 KiB
Python
166 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from vibe.core.config import TranscribeModelConfig, TranscribeProviderConfig
|
|
from vibe.core.transcribe import (
|
|
MistralTranscribeClient,
|
|
TranscribeDone,
|
|
TranscribeError,
|
|
TranscribeSessionCreated,
|
|
TranscribeTextDelta,
|
|
)
|
|
|
|
|
|
def _make_provider() -> TranscribeProviderConfig:
|
|
return TranscribeProviderConfig(
|
|
name="mistral",
|
|
api_base="https://api.mistral.ai",
|
|
api_key_env_var="MISTRAL_API_KEY",
|
|
)
|
|
|
|
|
|
def _make_model() -> TranscribeModelConfig:
|
|
return TranscribeModelConfig(
|
|
name="mistral-small-transcribe",
|
|
alias="foo",
|
|
provider="mistral",
|
|
encoding="pcm_s16le",
|
|
sample_rate=16_000,
|
|
target_streaming_delay_ms=200,
|
|
)
|
|
|
|
|
|
async def _empty_audio_stream() -> AsyncIterator[bytes]:
|
|
return
|
|
yield
|
|
|
|
|
|
def _make_sdk_session_created(request_id: str = "test-request-id") -> MagicMock:
|
|
from mistralai.client.models import (
|
|
RealtimeTranscriptionSession,
|
|
RealtimeTranscriptionSessionCreated,
|
|
)
|
|
|
|
session = MagicMock(spec=RealtimeTranscriptionSession)
|
|
session.request_id = request_id
|
|
mock = MagicMock(spec=RealtimeTranscriptionSessionCreated)
|
|
mock.session = session
|
|
return mock
|
|
|
|
|
|
def _make_sdk_text_delta(text: str) -> MagicMock:
|
|
from mistralai.client.models import TranscriptionStreamTextDelta
|
|
|
|
m = MagicMock(spec=TranscriptionStreamTextDelta)
|
|
m.text = text
|
|
return m
|
|
|
|
|
|
def _make_sdk_done(text: str) -> MagicMock:
|
|
from mistralai.client.models import TranscriptionStreamDone
|
|
|
|
m = MagicMock(spec=TranscriptionStreamDone)
|
|
m.text = text
|
|
return m
|
|
|
|
|
|
def _make_sdk_error(message: str) -> MagicMock:
|
|
from mistralai.client.models import RealtimeTranscriptionError
|
|
|
|
m = MagicMock(spec=RealtimeTranscriptionError)
|
|
m.error = MagicMock()
|
|
m.error.message = message
|
|
return m
|
|
|
|
|
|
def _make_sdk_unknown() -> MagicMock:
|
|
from mistralai.extra.realtime import UnknownRealtimeEvent
|
|
|
|
return MagicMock(spec=UnknownRealtimeEvent)
|
|
|
|
|
|
async def _collect(client: MistralTranscribeClient) -> list[object]:
|
|
events: list[object] = []
|
|
async for event in client.transcribe(_empty_audio_stream()):
|
|
events.append(event)
|
|
return events
|
|
|
|
|
|
def _patch_sdk(sdk_events: list[object]) -> MagicMock:
|
|
async def _fake_stream(**_kwargs: object) -> AsyncIterator[object]:
|
|
for e in sdk_events:
|
|
yield e
|
|
|
|
mock_client = MagicMock()
|
|
mock_client.audio.realtime.transcribe_stream = _fake_stream
|
|
return mock_client
|
|
|
|
|
|
class TestEventMapping:
|
|
@pytest.mark.asyncio
|
|
async def test_session_created(self) -> None:
|
|
mock_client = _patch_sdk([_make_sdk_session_created()])
|
|
client = MistralTranscribeClient(_make_provider(), _make_model())
|
|
client._client = mock_client
|
|
|
|
events = await _collect(client)
|
|
|
|
assert len(events) == 1
|
|
assert isinstance(events[0], TranscribeSessionCreated)
|
|
assert events[0].request_id == "test-request-id"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_text_delta(self) -> None:
|
|
mock_client = _patch_sdk([_make_sdk_text_delta("hello")])
|
|
client = MistralTranscribeClient(_make_provider(), _make_model())
|
|
client._client = mock_client
|
|
|
|
events = await _collect(client)
|
|
|
|
assert len(events) == 1
|
|
assert isinstance(events[0], TranscribeTextDelta)
|
|
assert events[0].text == "hello"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_done(self) -> None:
|
|
mock_client = _patch_sdk([_make_sdk_done("full text")])
|
|
client = MistralTranscribeClient(_make_provider(), _make_model())
|
|
client._client = mock_client
|
|
|
|
events = await _collect(client)
|
|
|
|
assert len(events) == 1
|
|
assert isinstance(events[0], TranscribeDone)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_error(self) -> None:
|
|
mock_client = _patch_sdk([_make_sdk_error("something broke")])
|
|
client = MistralTranscribeClient(_make_provider(), _make_model())
|
|
client._client = mock_client
|
|
|
|
events = await _collect(client)
|
|
|
|
assert len(events) == 1
|
|
assert isinstance(events[0], TranscribeError)
|
|
assert events[0].message == "something broke"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_unknown_event_is_skipped(self) -> None:
|
|
mock_client = _patch_sdk([
|
|
_make_sdk_session_created(),
|
|
_make_sdk_unknown(),
|
|
_make_sdk_text_delta("hi"),
|
|
])
|
|
client = MistralTranscribeClient(_make_provider(), _make_model())
|
|
client._client = mock_client
|
|
|
|
events = await _collect(client)
|
|
|
|
assert len(events) == 2
|
|
assert isinstance(events[0], TranscribeSessionCreated)
|
|
assert isinstance(events[1], TranscribeTextDelta)
|