mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-Authored-By: Quentin Torroba <quentin.torroba@mistral.ai> Co-Authored-By: Laure Hugo <laure.hugo@mistral.ai> Co-Authored-By: Benjamin Trom <benjamin.trom@mistral.ai> Co-Authored-By: Mathias Gesbert <mathias.gesbert@ext.mistral.ai> Co-Authored-By: Michel Thomazo <michel.thomazo@mistral.ai> Co-Authored-By: Clément Drouin <clement.drouin@mistral.ai> Co-Authored-By: Vincent Guilloux <vincent.guilloux@mistral.ai> Co-Authored-By: Valentin Berard <val@mistral.ai> Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
249 lines
9.9 KiB
Python
249 lines
9.9 KiB
Python
"""Test data for this module was generated using real LLM provider API responses,
|
|
with responses simplified and formatted to make them readable and maintainable.
|
|
|
|
To update or modify test parameters:
|
|
1. Make actual API calls to the target providers
|
|
2. Use the raw API responses as a base for updating test data
|
|
3. Simplify only where necessary for readability while preserving core structure
|
|
|
|
The closer test data remains to real API responses, the more reliable and accurate
|
|
the tests will be. Always prefer real API data over manually constructed examples.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import httpx
|
|
import pytest
|
|
import respx
|
|
|
|
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
|
|
from tests.backend.data.fireworks import (
|
|
SIMPLE_CONVERSATION_PARAMS as FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_SIMPLE_CONVERSATION_PARAMS as FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_TOOL_CONVERSATION_PARAMS as FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
|
|
TOOL_CONVERSATION_PARAMS as FIREWORKS_TOOL_CONVERSATION_PARAMS,
|
|
)
|
|
from tests.backend.data.mistral import (
|
|
SIMPLE_CONVERSATION_PARAMS as MISTRAL_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_SIMPLE_CONVERSATION_PARAMS as MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_TOOL_CONVERSATION_PARAMS as MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
|
|
TOOL_CONVERSATION_PARAMS as MISTRAL_TOOL_CONVERSATION_PARAMS,
|
|
)
|
|
from vibe.core.config import ModelConfig, ProviderConfig
|
|
from vibe.core.llm.backend.generic import GenericBackend
|
|
from vibe.core.llm.backend.mistral import MistralBackend
|
|
from vibe.core.llm.exceptions import BackendError
|
|
from vibe.core.llm.types import BackendLike
|
|
from vibe.core.types import LLMChunk, LLMMessage, Role, ToolCall
|
|
|
|
|
|
class TestBackend:
    """Exercise GenericBackend / MistralBackend against mocked provider APIs.

    Every test mocks ``POST {base_url}/v1/chat/completions`` with respx and
    feeds the backends (simplified) real provider payloads from the test-data
    modules, then checks how the responses are parsed.
    """

    @staticmethod
    def _provider(base_url: Url) -> ProviderConfig:
        """Provider config whose API base points at the mocked URL."""
        return ProviderConfig(
            name="provider_name",
            api_base=f"{base_url}/v1",
            api_key_env_var="API_KEY",
        )

    @staticmethod
    def _model() -> ModelConfig:
        """Model config matching the provider built by ``_provider``."""
        return ModelConfig(
            name="model_name", provider="provider_name", alias="model_alias"
        )

    @staticmethod
    def _backend_classes(
        base_url: Url,
    ) -> list[type[GenericBackend | MistralBackend]]:
        """Backends to exercise: GenericBackend always, plus MistralBackend
        when the mocked endpoint is the Mistral API."""
        return [
            GenericBackend,
            *([MistralBackend] if base_url == "https://api.mistral.ai" else []),
        ]

    @staticmethod
    def _assert_tool_calls(tool_calls: list[ToolCall], expected: list) -> None:
        """Check count, name, arguments and index of tool calls against test data.

        The explicit length check (via ``strict=True``) also covers the
        streaming path, which previously never compared counts.
        """
        for tool_call, expected_call in zip(tool_calls, expected, strict=True):
            assert tool_call.function.name == expected_call["name"]
            assert tool_call.function.arguments == expected_call["arguments"]
            assert tool_call.index == expected_call["index"]

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "base_url,json_response,result_data",
        [
            *FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
            *FIREWORKS_TOOL_CONVERSATION_PARAMS,
            *MISTRAL_SIMPLE_CONVERSATION_PARAMS,
            *MISTRAL_TOOL_CONVERSATION_PARAMS,
        ],
    )
    async def test_backend_complete(
        self, base_url: Url, json_response: JsonResponse, result_data: ResultData
    ):
        """Non-streaming completion: message, finish reason, usage, tool calls."""
        with respx.mock(base_url=base_url) as mock_api:
            mock_api.post("/v1/chat/completions").mock(
                return_value=httpx.Response(status_code=200, json=json_response)
            )
            provider = self._provider(base_url)

            for BackendClass in self._backend_classes(base_url):
                backend: BackendLike = BackendClass(provider=provider)
                messages = [LLMMessage(role=Role.user, content="Just say hi")]

                result = await backend.complete(
                    model=self._model(),
                    messages=messages,
                    temperature=0.2,
                    tools=None,
                    max_tokens=None,
                    tool_choice=None,
                    extra_headers=None,
                )

                assert result.message.content == result_data["message"]
                assert result.finish_reason == result_data["finish_reason"]
                assert result.usage is not None
                assert (
                    result.usage.prompt_tokens == result_data["usage"]["prompt_tokens"]
                )
                assert (
                    result.usage.completion_tokens
                    == result_data["usage"]["completion_tokens"]
                )

                # BUG FIX: this was `return`, which bailed out of the whole
                # test from inside the backend-class loop — MistralBackend was
                # never exercised for simple (no-tool) conversations. Using
                # `continue` (as the streaming test already does) checks every
                # backend class.
                if result.message.tool_calls is None:
                    continue

                self._assert_tool_calls(
                    result.message.tool_calls, result_data["tool_calls"]
                )

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "base_url,chunks,result_data",
        [
            *FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
            *FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
            *MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
            *MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
        ],
    )
    async def test_backend_complete_streaming(
        self, base_url: Url, chunks: list[Chunk], result_data: list[ResultData]
    ):
        """Streaming completion: each SSE chunk maps onto one expected result."""
        with respx.mock(base_url=base_url) as mock_api:
            mock_api.post("/v1/chat/completions").mock(
                return_value=httpx.Response(
                    status_code=200,
                    # Chunks are joined with blank lines, mimicking SSE framing.
                    stream=httpx.ByteStream(stream=b"\n\n".join(chunks)),
                    headers={"Content-Type": "text/event-stream"},
                )
            )
            provider = self._provider(base_url)

            for BackendClass in self._backend_classes(base_url):
                backend: BackendLike = BackendClass(provider=provider)
                messages = [
                    LLMMessage(role=Role.user, content="List files in current dir")
                ]

                results: list[LLMChunk] = []
                async for result in backend.complete_streaming(
                    model=self._model(),
                    messages=messages,
                    temperature=0.2,
                    tools=None,
                    max_tokens=None,
                    tool_choice=None,
                    extra_headers=None,
                ):
                    results.append(result)

                # strict=True: the backend must emit exactly as many chunks as
                # the test data expects.
                for result, expected_result in zip(results, result_data, strict=True):
                    assert result.message.content == expected_result["message"]
                    assert result.finish_reason == expected_result["finish_reason"]
                    assert result.usage is not None
                    assert (
                        result.usage.prompt_tokens
                        == expected_result["usage"]["prompt_tokens"]
                    )
                    assert (
                        result.usage.completion_tokens
                        == expected_result["usage"]["completion_tokens"]
                    )

                    if result.message.tool_calls is None:
                        continue

                    self._assert_tool_calls(
                        result.message.tool_calls, expected_result["tool_calls"]
                    )

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "base_url,backend_class,response",
        [
            (
                "https://api.fireworks.ai",
                GenericBackend,
                httpx.Response(status_code=500, text="Internal Server Error"),
            ),
            (
                "https://api.fireworks.ai",
                GenericBackend,
                httpx.Response(status_code=429, text="Rate Limit Exceeded"),
            ),
            (
                "https://api.mistral.ai",
                MistralBackend,
                httpx.Response(status_code=500, text="Internal Server Error"),
            ),
            (
                "https://api.mistral.ai",
                MistralBackend,
                httpx.Response(status_code=429, text="Rate Limit Exceeded"),
            ),
        ],
    )
    async def test_backend_complete_streaming_error(
        self,
        base_url: Url,
        backend_class: type[MistralBackend | GenericBackend],
        response: httpx.Response,
    ):
        """HTTP errors (500, 429) during streaming surface as BackendError
        carrying the response's status code and reason phrase."""
        with respx.mock(base_url=base_url) as mock_api:
            mock_api.post("/v1/chat/completions").mock(return_value=response)
            backend = backend_class(provider=self._provider(base_url))
            messages = [LLMMessage(role=Role.user, content="Just say hi")]

            with pytest.raises(BackendError) as e:
                # Must iterate: the error is raised while consuming the stream,
                # not when the coroutine/generator is created.
                async for _ in backend.complete_streaming(
                    model=self._model(),
                    messages=messages,
                    temperature=0.2,
                    tools=None,
                    max_tokens=None,
                    tool_choice=None,
                    extra_headers=None,
                ):
                    pass

            assert e.value.status == response.status_code
            assert e.value.reason == response.reason_phrase
            # No structured error body is mocked, so parsing must yield None.
            assert e.value.parsed_error is None