mistral-vibe/tests/backend/test_backend.py

"""Test data for this module was generated using real LLM provider API responses,
then simplified and formatted to keep them readable and maintainable.

To update or modify test parameters:

1. Make actual API calls to the target providers
2. Use the raw API responses as a base for updating test data
3. Simplify only where necessary for readability while preserving core structure

The closer test data remains to real API responses, the more reliable and accurate
the tests will be. Always prefer real API data over manually constructed examples.
"""
from __future__ import annotations

import httpx
import pytest
import respx

from tests.backend.data import Chunk, JsonResponse, ResultData, Url
from tests.backend.data.fireworks import (
SIMPLE_CONVERSATION_PARAMS as FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
STREAMED_SIMPLE_CONVERSATION_PARAMS as FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
STREAMED_TOOL_CONVERSATION_PARAMS as FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
TOOL_CONVERSATION_PARAMS as FIREWORKS_TOOL_CONVERSATION_PARAMS,
)
from tests.backend.data.mistral import (
SIMPLE_CONVERSATION_PARAMS as MISTRAL_SIMPLE_CONVERSATION_PARAMS,
STREAMED_SIMPLE_CONVERSATION_PARAMS as MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
STREAMED_TOOL_CONVERSATION_PARAMS as MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
TOOL_CONVERSATION_PARAMS as MISTRAL_TOOL_CONVERSATION_PARAMS,
)
from vibe.core.config import ModelConfig, ProviderConfig
from vibe.core.llm.backend.generic import GenericBackend
from vibe.core.llm.backend.mistral import MistralBackend
from vibe.core.llm.exceptions import BackendError
from vibe.core.llm.types import BackendLike
from vibe.core.types import LLMChunk, LLMMessage, Role, ToolCall
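

# A minimal, hypothetical sketch of the workflow described in the module docstring:
# call a provider's OpenAI-compatible chat-completions endpoint and capture the raw
# JSON to use as a base for new entries under tests/backend/data. It is not used by
# the tests below; the helper name, the trivial prompt, and the "model_name"
# placeholder are illustrative assumptions, while the "/v1/chat/completions" path
# mirrors the endpoint mocked in the tests.
def _fetch_raw_completion_example(base_url: str, api_key: str) -> dict:
    """Return the raw chat-completions JSON for a trivial prompt."""
    response = httpx.post(
        f"{base_url}/v1/chat/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json={
            "model": "model_name",
            "messages": [{"role": "user", "content": "Just say hi"}],
        },
        timeout=30.0,
    )
    response.raise_for_status()
    return response.json()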


class TestBackend:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"base_url,json_response,result_data",
[
*FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
*FIREWORKS_TOOL_CONVERSATION_PARAMS,
*MISTRAL_SIMPLE_CONVERSATION_PARAMS,
*MISTRAL_TOOL_CONVERSATION_PARAMS,
],
)
async def test_backend_complete(
self, base_url: Url, json_response: JsonResponse, result_data: ResultData
):
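        # Mock the provider's chat-completions endpoint with a canned JSON completion.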
with respx.mock(base_url=base_url) as mock_api:
mock_api.post("/v1/chat/completions").mock(
return_value=httpx.Response(status_code=200, json=json_response)
)
provider = ProviderConfig(
name="provider_name",
api_base=f"{base_url}/v1",
api_key_env_var="API_KEY",
)
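            # GenericBackend is exercised for every provider; MistralBackend is also
            # exercised when the base URL is Mistral's.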
BackendClasses = [
GenericBackend,
*([MistralBackend] if base_url == "https://api.mistral.ai" else []),
]
for BackendClass in BackendClasses:
backend: BackendLike = BackendClass(provider=provider)
model = ModelConfig(
name="model_name", provider="provider_name", alias="model_alias"
)
messages = [LLMMessage(role=Role.user, content="Just say hi")]
result = await backend.complete(
model=model,
messages=messages,
temperature=0.2,
tools=None,
max_tokens=None,
tool_choice=None,
extra_headers=None,
)
assert result.message.content == result_data["message"]
assert result.finish_reason == result_data["finish_reason"]
assert result.usage is not None
assert (
result.usage.prompt_tokens == result_data["usage"]["prompt_tokens"]
)
assert (
result.usage.completion_tokens
== result_data["usage"]["completion_tokens"]
)
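                # Simple-conversation cases carry no tool calls, so the tool-call
                # assertions below do not apply.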
                if result.message.tool_calls is None:
                    continue
assert len(result.message.tool_calls) == len(result_data["tool_calls"])
for i, tool_call in enumerate[ToolCall](result.message.tool_calls):
assert (
tool_call.function.name == result_data["tool_calls"][i]["name"]
)
assert (
tool_call.function.arguments
== result_data["tool_calls"][i]["arguments"]
)
assert tool_call.index == result_data["tool_calls"][i]["index"]

    @pytest.mark.asyncio
@pytest.mark.parametrize(
"base_url,chunks,result_data",
[
*FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
*FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
*MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
*MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
],
)
async def test_backend_complete_streaming(
self, base_url: Url, chunks: list[Chunk], result_data: list[ResultData]
):
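        # Serve the recorded chunk fixtures back as a server-sent-events byte stream.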
with respx.mock(base_url=base_url) as mock_api:
mock_api.post("/v1/chat/completions").mock(
return_value=httpx.Response(
status_code=200,
stream=httpx.ByteStream(stream=b"\n\n".join(chunks)),
headers={"Content-Type": "text/event-stream"},
)
)
provider = ProviderConfig(
name="provider_name",
api_base=f"{base_url}/v1",
api_key_env_var="API_KEY",
)
BackendClasses = [
GenericBackend,
*([MistralBackend] if base_url == "https://api.mistral.ai" else []),
]
for BackendClass in BackendClasses:
backend: BackendLike = BackendClass(provider=provider)
model = ModelConfig(
name="model_name", provider="provider_name", alias="model_alias"
)
messages = [
LLMMessage(role=Role.user, content="List files in current dir")
]
results: list[LLMChunk] = []
async for result in backend.complete_streaming(
model=model,
messages=messages,
temperature=0.2,
tools=None,
max_tokens=None,
tool_choice=None,
extra_headers=None,
):
results.append(result)
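                # Compare each streamed chunk with its expected result; the strict zip
                # guards against missing or extra chunks.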
for result, expected_result in zip(results, result_data, strict=True):
assert result.message.content == expected_result["message"]
assert result.finish_reason == expected_result["finish_reason"]
assert result.usage is not None
assert (
result.usage.prompt_tokens
== expected_result["usage"]["prompt_tokens"]
)
assert (
result.usage.completion_tokens
== expected_result["usage"]["completion_tokens"]
)
if result.message.tool_calls is None:
continue
for i, tool_call in enumerate(result.message.tool_calls):
assert (
tool_call.function.name
== expected_result["tool_calls"][i]["name"]
)
assert (
tool_call.function.arguments
== expected_result["tool_calls"][i]["arguments"]
)
assert (
tool_call.index == expected_result["tool_calls"][i]["index"]
)

    @pytest.mark.asyncio
@pytest.mark.parametrize(
"base_url,backend_class,response",
[
(
"https://api.fireworks.ai",
GenericBackend,
httpx.Response(status_code=500, text="Internal Server Error"),
),
(
"https://api.fireworks.ai",
GenericBackend,
httpx.Response(status_code=429, text="Rate Limit Exceeded"),
),
(
"https://api.mistral.ai",
MistralBackend,
httpx.Response(status_code=500, text="Internal Server Error"),
),
(
"https://api.mistral.ai",
MistralBackend,
httpx.Response(status_code=429, text="Rate Limit Exceeded"),
),
],
)
async def test_backend_complete_streaming_error(
self,
base_url: Url,
backend_class: type[MistralBackend | GenericBackend],
response: httpx.Response,
):
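        # The mocked endpoint returns an HTTP error, which the backend should surface
        # as a BackendError that carries the response status and reason phrase.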
with respx.mock(base_url=base_url) as mock_api:
mock_api.post("/v1/chat/completions").mock(return_value=response)
provider = ProviderConfig(
name="provider_name",
api_base=f"{base_url}/v1",
api_key_env_var="API_KEY",
)
backend = backend_class(provider=provider)
model = ModelConfig(
name="model_name", provider="provider_name", alias="model_alias"
)
messages = [LLMMessage(role=Role.user, content="Just say hi")]
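            # Iterating the stream should raise as soon as the mocked error response
            # is received.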
with pytest.raises(BackendError) as e:
async for _ in backend.complete_streaming(
model=model,
messages=messages,
temperature=0.2,
tools=None,
max_tokens=None,
tool_choice=None,
extra_headers=None,
):
pass
assert e.value.status == response.status_code
assert e.value.reason == response.reason_phrase
assert e.value.parsed_error is None