mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai> Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai> Co-authored-by: Lucas Marandat <31749711+lucasmrdt@users.noreply.github.com> Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com> Co-authored-by: Paul Cacheux <paul.cacheux@mistral.ai> Co-authored-by: Peter Evers <pevers90@gmail.com> Co-authored-by: Pierre Rossinès <pierre.rossines@mistral.ai> Co-authored-by: Pierre Rossinès <pierre.rossines@protonmail.com> Co-authored-by: Quentin <quentin.torroba@mistral.ai> Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai> Co-authored-by: Val <102326092+vdeva@users.noreply.github.com> Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com> Co-authored-by: Mistral Vibe <vibe@mistral.ai>
698 lines
27 KiB
Python
698 lines
27 KiB
Python
"""Test data for this module was generated using real LLM provider API responses,
|
|
with responses simplified and formatted to make them readable and maintainable.
|
|
|
|
To update or modify test parameters:
|
|
1. Make actual API calls to the target providers
|
|
2. Use the raw API responses as a base for updating test data
|
|
3. Simplify only where necessary for readability while preserving core structure
|
|
|
|
The closer test data remains to real API responses, the more reliable and accurate
|
|
the tests will be. Always prefer real API data over manually constructed examples.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import ClassVar, Literal
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
from mistralai.client.errors import SDKError
|
|
from mistralai.client.models import AssistantMessage
|
|
from mistralai.client.utils.retries import BackoffStrategy, RetryConfig
|
|
import pytest
|
|
import respx
|
|
|
|
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
|
|
from tests.backend.data.fireworks import (
|
|
SIMPLE_CONVERSATION_PARAMS as FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_SIMPLE_CONVERSATION_PARAMS as FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_TOOL_CONVERSATION_PARAMS as FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
|
|
TOOL_CONVERSATION_PARAMS as FIREWORKS_TOOL_CONVERSATION_PARAMS,
|
|
)
|
|
from tests.backend.data.mistral import (
|
|
SIMPLE_CONVERSATION_PARAMS as MISTRAL_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_SIMPLE_CONVERSATION_PARAMS as MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
|
STREAMED_TOOL_CONVERSATION_PARAMS as MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
|
|
TOOL_CONVERSATION_PARAMS as MISTRAL_TOOL_CONVERSATION_PARAMS,
|
|
)
|
|
from vibe.core.config import ModelConfig, ProviderConfig
|
|
from vibe.core.llm.backend.factory import BACKEND_FACTORY
|
|
from vibe.core.llm.backend.generic import GenericBackend
|
|
from vibe.core.llm.backend.mistral import MistralBackend, MistralMapper
|
|
from vibe.core.llm.exceptions import BackendError, BackendErrorBuilder
|
|
from vibe.core.llm.types import BackendLike
|
|
from vibe.core.types import Backend, FunctionCall, LLMChunk, LLMMessage, Role, ToolCall
|
|
from vibe.core.utils import get_user_agent
|
|
|
|
|
|
class TestBackend:
|
|
@staticmethod
|
|
def _build_fast_retry_config() -> RetryConfig:
|
|
return RetryConfig(
|
|
strategy="backoff",
|
|
backoff=BackoffStrategy(
|
|
initial_interval=1, max_interval=1, exponent=1, max_elapsed_time=1
|
|
),
|
|
retry_connection_errors=False,
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"base_url,json_response,result_data",
|
|
[
|
|
*FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
|
|
*FIREWORKS_TOOL_CONVERSATION_PARAMS,
|
|
*MISTRAL_SIMPLE_CONVERSATION_PARAMS,
|
|
*MISTRAL_TOOL_CONVERSATION_PARAMS,
|
|
],
|
|
)
|
|
async def test_backend_complete(
|
|
self, base_url: Url, json_response: JsonResponse, result_data: ResultData
|
|
):
|
|
with respx.mock(base_url=base_url) as mock_api:
|
|
mock_api.post("/v1/chat/completions").mock(
|
|
return_value=httpx.Response(status_code=200, json=json_response)
|
|
)
|
|
provider = ProviderConfig(
|
|
name="provider_name",
|
|
api_base=f"{base_url}/v1",
|
|
api_key_env_var="API_KEY",
|
|
)
|
|
|
|
BackendClasses = [
|
|
GenericBackend,
|
|
*([MistralBackend] if base_url == "https://api.mistral.ai" else []),
|
|
]
|
|
for BackendClass in BackendClasses:
|
|
backend: BackendLike = BackendClass(provider=provider)
|
|
model = ModelConfig(
|
|
name="model_name", provider="provider_name", alias="model_alias"
|
|
)
|
|
messages = [LLMMessage(role=Role.user, content="Just say hi")]
|
|
|
|
result = await backend.complete(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=0.2,
|
|
tools=None,
|
|
max_tokens=None,
|
|
tool_choice=None,
|
|
extra_headers=None,
|
|
)
|
|
|
|
assert result.message.content == result_data["message"]
|
|
assert result.usage is not None
|
|
assert (
|
|
result.usage.prompt_tokens == result_data["usage"]["prompt_tokens"]
|
|
)
|
|
assert (
|
|
result.usage.completion_tokens
|
|
== result_data["usage"]["completion_tokens"]
|
|
)
|
|
|
|
if result.message.tool_calls is None:
|
|
return
|
|
|
|
assert len(result.message.tool_calls) == len(result_data["tool_calls"])
|
|
for i, tool_call in enumerate[ToolCall](result.message.tool_calls):
|
|
assert (
|
|
tool_call.function.name == result_data["tool_calls"][i]["name"]
|
|
)
|
|
assert (
|
|
tool_call.function.arguments
|
|
== result_data["tool_calls"][i]["arguments"]
|
|
)
|
|
assert tool_call.index == result_data["tool_calls"][i]["index"]
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"base_url,chunks,result_data",
|
|
[
|
|
*FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
|
*FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
|
|
*MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
|
*MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
|
|
],
|
|
)
|
|
async def test_backend_complete_streaming(
|
|
self, base_url: Url, chunks: list[Chunk], result_data: list[ResultData]
|
|
):
|
|
with respx.mock(base_url=base_url) as mock_api:
|
|
mock_api.post("/v1/chat/completions").mock(
|
|
return_value=httpx.Response(
|
|
status_code=200,
|
|
stream=httpx.ByteStream(stream=b"\n\n".join(chunks)),
|
|
headers={"Content-Type": "text/event-stream"},
|
|
)
|
|
)
|
|
provider = ProviderConfig(
|
|
name="provider_name",
|
|
api_base=f"{base_url}/v1",
|
|
api_key_env_var="API_KEY",
|
|
)
|
|
BackendClasses = [
|
|
GenericBackend,
|
|
*([MistralBackend] if base_url == "https://api.mistral.ai" else []),
|
|
]
|
|
for BackendClass in BackendClasses:
|
|
backend: BackendLike = BackendClass(provider=provider)
|
|
model = ModelConfig(
|
|
name="model_name", provider="provider_name", alias="model_alias"
|
|
)
|
|
|
|
messages = [
|
|
LLMMessage(role=Role.user, content="List files in current dir")
|
|
]
|
|
|
|
results: list[LLMChunk] = []
|
|
async for result in backend.complete_streaming(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=0.2,
|
|
tools=None,
|
|
max_tokens=None,
|
|
tool_choice=None,
|
|
extra_headers=None,
|
|
):
|
|
results.append(result)
|
|
|
|
for result, expected_result in zip(results, result_data, strict=True):
|
|
assert result.message.content == expected_result["message"]
|
|
assert result.usage is not None
|
|
assert (
|
|
result.usage.prompt_tokens
|
|
== expected_result["usage"]["prompt_tokens"]
|
|
)
|
|
assert (
|
|
result.usage.completion_tokens
|
|
== expected_result["usage"]["completion_tokens"]
|
|
)
|
|
|
|
if result.message.tool_calls is None:
|
|
continue
|
|
|
|
for i, tool_call in enumerate(result.message.tool_calls):
|
|
assert (
|
|
tool_call.function.name
|
|
== expected_result["tool_calls"][i]["name"]
|
|
)
|
|
assert (
|
|
tool_call.function.arguments
|
|
== expected_result["tool_calls"][i]["arguments"]
|
|
)
|
|
assert (
|
|
tool_call.index == expected_result["tool_calls"][i]["index"]
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"base_url,backend_class,response",
|
|
[
|
|
(
|
|
"https://api.fireworks.ai",
|
|
GenericBackend,
|
|
httpx.Response(status_code=500, text="Internal Server Error"),
|
|
),
|
|
(
|
|
"https://api.fireworks.ai",
|
|
GenericBackend,
|
|
httpx.Response(status_code=429, text="Rate Limit Exceeded"),
|
|
),
|
|
(
|
|
"https://api.mistral.ai",
|
|
MistralBackend,
|
|
httpx.Response(status_code=500, text="Internal Server Error"),
|
|
),
|
|
(
|
|
"https://api.mistral.ai",
|
|
MistralBackend,
|
|
httpx.Response(status_code=429, text="Rate Limit Exceeded"),
|
|
),
|
|
],
|
|
)
|
|
async def test_backend_complete_streaming_error(
|
|
self,
|
|
base_url: Url,
|
|
backend_class: type[MistralBackend | GenericBackend],
|
|
response: httpx.Response,
|
|
):
|
|
with respx.mock(base_url=base_url) as mock_api:
|
|
mock_api.post("/v1/chat/completions").mock(return_value=response)
|
|
provider = ProviderConfig(
|
|
name="provider_name",
|
|
api_base=f"{base_url}/v1",
|
|
api_key_env_var="API_KEY",
|
|
)
|
|
backend = backend_class(provider=provider)
|
|
if isinstance(backend, MistralBackend):
|
|
backend._retry_config = self._build_fast_retry_config()
|
|
model = ModelConfig(
|
|
name="model_name", provider="provider_name", alias="model_alias"
|
|
)
|
|
messages = [LLMMessage(role=Role.user, content="Just say hi")]
|
|
with pytest.raises(BackendError) as e:
|
|
async for _ in backend.complete_streaming(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=0.2,
|
|
tools=None,
|
|
max_tokens=None,
|
|
tool_choice=None,
|
|
extra_headers=None,
|
|
):
|
|
pass
|
|
assert e.value.status == response.status_code
|
|
assert e.value.reason == response.reason_phrase
|
|
assert e.value.parsed_error is None
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
"base_url,provider_name,expected_stream_options",
|
|
[
|
|
("https://api.fireworks.ai", "fireworks", {"include_usage": True}),
|
|
(
|
|
"https://api.mistral.ai",
|
|
"mistral",
|
|
{"include_usage": True, "stream_tool_calls": True},
|
|
),
|
|
],
|
|
)
|
|
async def test_backend_streaming_payload_includes_stream_options(
|
|
self, base_url: Url, provider_name: str, expected_stream_options: dict
|
|
):
|
|
with respx.mock(base_url=base_url) as mock_api:
|
|
route = mock_api.post("/v1/chat/completions").mock(
|
|
return_value=httpx.Response(
|
|
status_code=200,
|
|
stream=httpx.ByteStream(
|
|
b'data: {"choices": [{"delta": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 10, "completion_tokens": 5}}\n\ndata: [DONE]\n\n'
|
|
),
|
|
headers={"Content-Type": "text/event-stream"},
|
|
)
|
|
)
|
|
provider = ProviderConfig(
|
|
name=provider_name, api_base=f"{base_url}/v1", api_key_env_var="API_KEY"
|
|
)
|
|
backend = GenericBackend(provider=provider)
|
|
model = ModelConfig(
|
|
name="model_name", provider=provider_name, alias="model_alias"
|
|
)
|
|
messages = [LLMMessage(role=Role.user, content="hi")]
|
|
|
|
async for _ in backend.complete_streaming(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=0.2,
|
|
tools=None,
|
|
max_tokens=None,
|
|
tool_choice=None,
|
|
extra_headers=None,
|
|
):
|
|
pass
|
|
|
|
assert route.called
|
|
request = route.calls.last.request
|
|
payload = json.loads(request.content)
|
|
|
|
assert payload["stream"] is True
|
|
assert payload["stream_options"] == expected_stream_options
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize("backend_type", [Backend.MISTRAL, Backend.GENERIC])
|
|
async def test_backend_user_agent(self, backend_type: Backend):
|
|
user_agent = get_user_agent(backend_type)
|
|
base_url = "https://api.example.com"
|
|
json_response = {
|
|
"id": "fake_id_1234",
|
|
"created": 1234567890,
|
|
"model": "devstral-latest",
|
|
"usage": {
|
|
"prompt_tokens": 100,
|
|
"total_tokens": 300,
|
|
"completion_tokens": 200,
|
|
},
|
|
"object": "chat.completion",
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"finish_reason": "stop",
|
|
"message": {
|
|
"role": "assistant",
|
|
"tool_calls": None,
|
|
"content": "Hey",
|
|
},
|
|
}
|
|
],
|
|
}
|
|
with respx.mock(base_url=base_url) as mock_api:
|
|
mock_api.post("/v1/chat/completions").mock(
|
|
return_value=httpx.Response(status_code=200, json=json_response)
|
|
)
|
|
|
|
provider = ProviderConfig(
|
|
name="provider_name",
|
|
api_base=f"{base_url}/v1",
|
|
api_key_env_var="API_KEY",
|
|
)
|
|
backend = BACKEND_FACTORY[backend_type](provider=provider)
|
|
model = ModelConfig(
|
|
name="model_name", provider="provider_name", alias="model_alias"
|
|
)
|
|
messages = [LLMMessage(role=Role.user, content="Just say hi")]
|
|
|
|
await backend.complete(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=0.2,
|
|
tools=None,
|
|
max_tokens=None,
|
|
tool_choice=None,
|
|
extra_headers={"user-agent": user_agent},
|
|
)
|
|
|
|
assert mock_api.calls.last.request.headers["user-agent"] == user_agent
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize("backend_type", [Backend.MISTRAL, Backend.GENERIC])
|
|
async def test_backend_user_agent_when_streaming(self, backend_type: Backend):
|
|
user_agent = get_user_agent(backend_type)
|
|
|
|
base_url = "https://api.example.com"
|
|
with respx.mock(base_url=base_url) as mock_api:
|
|
chunks = [
|
|
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"role":"assistant","content":"Hey"},"finish_reason":"stop"}]}'
|
|
]
|
|
mock_response = httpx.Response(
|
|
status_code=200,
|
|
stream=httpx.ByteStream(stream=b"\n\n".join(chunks)),
|
|
headers={"Content-Type": "text/event-stream"},
|
|
)
|
|
mock_api.post("/v1/chat/completions").mock(return_value=mock_response)
|
|
|
|
provider = ProviderConfig(
|
|
name="provider_name",
|
|
api_base=f"{base_url}/v1",
|
|
api_key_env_var="API_KEY",
|
|
)
|
|
backend = BACKEND_FACTORY[backend_type](provider=provider)
|
|
model = ModelConfig(
|
|
name="model_name", provider="provider_name", alias="model_alias"
|
|
)
|
|
messages = [LLMMessage(role=Role.user, content="Just say hi")]
|
|
|
|
async for _ in backend.complete_streaming(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=0.2,
|
|
tools=None,
|
|
max_tokens=None,
|
|
tool_choice=None,
|
|
extra_headers={"user-agent": user_agent},
|
|
):
|
|
pass
|
|
|
|
assert mock_api.calls.last.request.headers["user-agent"] == user_agent
|
|
|
|
|
|
class TestMistralRetry:
|
|
@staticmethod
|
|
def _create_test_backend() -> MistralBackend:
|
|
provider = ProviderConfig(
|
|
name="test_provider",
|
|
api_base="https://api.mistral.ai/v1",
|
|
api_key_env_var="API_KEY",
|
|
)
|
|
return MistralBackend(provider=provider)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_client_creation_includes_timeout_and_retry_config(self):
|
|
backend = self._create_test_backend()
|
|
|
|
with patch("vibe.core.llm.backend.mistral.Mistral") as mock_mistral_class:
|
|
mock_mistral_class.return_value = MagicMock()
|
|
backend._get_client()
|
|
mock_mistral_class.assert_called_once_with(
|
|
api_key=backend._api_key,
|
|
server_url=backend._server_url,
|
|
timeout_ms=720000,
|
|
retry_config=backend._retry_config,
|
|
)
|
|
|
|
|
|
class TestMistralMapperPrepareMessage:
|
|
"""Tests for MistralMapper.prepare_message thinking-block handling.
|
|
|
|
The Mistral API returns assistant messages with reasoning as a single
|
|
ThinkChunk (no trailing TextChunk when there is no text content). When
|
|
the mapper rebuilds the message for the next request it must NOT append
|
|
an empty TextChunk, otherwise the proxy's history-consistency check
|
|
sees a content mismatch on every turn and creates spurious conversation
|
|
segments.
|
|
"""
|
|
|
|
@pytest.fixture
|
|
def mapper(self) -> MistralMapper:
|
|
return MistralMapper()
|
|
|
|
def test_reasoning_only_no_empty_text_chunk(self, mapper: MistralMapper) -> None:
|
|
"""Assistant with reasoning_content but no text content should produce
|
|
only a ThinkChunk — no trailing empty TextChunk.
|
|
"""
|
|
msg = LLMMessage(
|
|
role=Role.assistant,
|
|
content=None,
|
|
reasoning_content="Let me think step by step.",
|
|
)
|
|
result = mapper.prepare_message(msg)
|
|
content = result.content
|
|
assert isinstance(content, list)
|
|
assert len(content) == 1
|
|
assert content[0].type == "thinking"
|
|
|
|
def test_reasoning_with_empty_string_content(self, mapper: MistralMapper) -> None:
|
|
"""content='' (empty string) should also not produce a trailing TextChunk."""
|
|
msg = LLMMessage(
|
|
role=Role.assistant, content="", reasoning_content="Thinking..."
|
|
)
|
|
result = mapper.prepare_message(msg)
|
|
content = result.content
|
|
assert isinstance(content, list)
|
|
assert len(content) == 1
|
|
assert content[0].type == "thinking"
|
|
|
|
def test_reasoning_with_text_content(self, mapper: MistralMapper) -> None:
|
|
"""When there is actual text content, both ThinkChunk and TextChunk
|
|
should be present.
|
|
"""
|
|
msg = LLMMessage(
|
|
role=Role.assistant,
|
|
content="Here is the answer.",
|
|
reasoning_content="Let me reason.",
|
|
)
|
|
result = mapper.prepare_message(msg)
|
|
content = result.content
|
|
assert isinstance(content, list)
|
|
assert len(content) == 2
|
|
assert content[0].type == "thinking"
|
|
assert content[1].type == "text"
|
|
assert content[1].text == "Here is the answer."
|
|
|
|
def test_reasoning_with_tool_calls_no_text(self, mapper: MistralMapper) -> None:
|
|
"""Reasoning + tool_calls but no text content — only ThinkChunk."""
|
|
msg = LLMMessage(
|
|
role=Role.assistant,
|
|
content=None,
|
|
reasoning_content="I should run a command.",
|
|
tool_calls=[
|
|
ToolCall(
|
|
id="tc_1",
|
|
index=0,
|
|
function=FunctionCall(name="bash", arguments='{"cmd": "ls"}'),
|
|
)
|
|
],
|
|
)
|
|
result = mapper.prepare_message(msg)
|
|
assert isinstance(result, AssistantMessage)
|
|
content = result.content
|
|
assert isinstance(content, list)
|
|
assert len(content) == 1
|
|
assert content[0].type == "thinking"
|
|
# Tool calls should still be present
|
|
assert isinstance(result.tool_calls, list)
|
|
assert len(result.tool_calls) == 1
|
|
assert result.tool_calls[0].function.name == "bash"
|
|
|
|
def test_no_reasoning_plain_string(self, mapper: MistralMapper) -> None:
|
|
"""Without reasoning_content, content is a plain string."""
|
|
msg = LLMMessage(role=Role.assistant, content="Hello!")
|
|
result = mapper.prepare_message(msg)
|
|
assert result.content == "Hello!"
|
|
|
|
|
|
class TestMistralBackendReasoningEffort:
|
|
"""Tests that MistralBackend correctly passes reasoning_effort to the SDK."""
|
|
|
|
@pytest.fixture
|
|
def backend(self) -> MistralBackend:
|
|
provider = ProviderConfig(
|
|
name="mistral",
|
|
api_base="https://api.mistral.ai/v1",
|
|
api_key_env_var="API_KEY",
|
|
)
|
|
return MistralBackend(provider=provider)
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.parametrize(
|
|
("thinking", "expected_effort", "expected_temperature"),
|
|
[
|
|
("off", None, 0.2),
|
|
("low", "none", 1.0),
|
|
("medium", "high", 1.0),
|
|
("high", "high", 1.0),
|
|
],
|
|
)
|
|
async def test_complete_passes_reasoning_effort(
|
|
self,
|
|
backend: MistralBackend,
|
|
thinking: Literal["off", "low", "medium", "high"],
|
|
expected_effort: str | None,
|
|
expected_temperature: float,
|
|
) -> None:
|
|
model = ModelConfig(
|
|
name="mistral-small-latest",
|
|
provider="mistral",
|
|
alias="mistral-small",
|
|
thinking=thinking,
|
|
)
|
|
messages = [LLMMessage(role=Role.user, content="hi")]
|
|
|
|
with patch.object(backend, "_get_client") as mock_get_client:
|
|
mock_client = MagicMock()
|
|
mock_response = MagicMock()
|
|
mock_response.choices = [MagicMock()]
|
|
mock_response.choices[0].message.content = "hello"
|
|
mock_response.choices[0].message.tool_calls = None
|
|
mock_response.usage.prompt_tokens = 10
|
|
mock_response.usage.completion_tokens = 5
|
|
mock_client.chat.complete_async = AsyncMock(return_value=mock_response)
|
|
mock_get_client.return_value = mock_client
|
|
|
|
await backend.complete(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=0.2,
|
|
tools=None,
|
|
max_tokens=None,
|
|
tool_choice=None,
|
|
extra_headers=None,
|
|
)
|
|
|
|
call_kwargs = mock_client.chat.complete_async.call_args.kwargs
|
|
assert call_kwargs["reasoning_effort"] == expected_effort
|
|
assert call_kwargs["temperature"] == expected_temperature
|
|
|
|
|
|
class TestBuildHttpErrorBodyReading:
|
|
_MESSAGES: ClassVar[list[LLMMessage]] = [LLMMessage(role=Role.user, content="hi")]
|
|
_COMMON_KWARGS: ClassVar[dict] = dict(
|
|
provider="test",
|
|
endpoint="https://api.test.com",
|
|
model="test-model",
|
|
messages=_MESSAGES,
|
|
temperature=0.2,
|
|
has_tools=False,
|
|
tool_choice=None,
|
|
)
|
|
|
|
def _make_sdk_error(self, response: httpx.Response) -> SDKError:
|
|
return SDKError("sdk error", response)
|
|
|
|
def _make_http_status_error(
|
|
self, response: httpx.Response
|
|
) -> httpx.HTTPStatusError:
|
|
return httpx.HTTPStatusError(
|
|
"http error", request=response.request, response=response
|
|
)
|
|
|
|
def test_sdk_error_readable_body(self) -> None:
|
|
response = httpx.Response(
|
|
400,
|
|
json={"message": "invalid temperature"},
|
|
request=httpx.Request("POST", "https://api.test.com"),
|
|
)
|
|
err = BackendErrorBuilder.build_http_error(
|
|
error=self._make_sdk_error(response), **self._COMMON_KWARGS
|
|
)
|
|
assert err.status == 400
|
|
assert err.parsed_error == "invalid temperature"
|
|
assert "invalid temperature" in err.body_text
|
|
|
|
def test_http_status_error_readable_body(self) -> None:
|
|
response = httpx.Response(
|
|
400,
|
|
json={"message": "invalid temperature"},
|
|
request=httpx.Request("POST", "https://api.test.com"),
|
|
)
|
|
err = BackendErrorBuilder.build_http_error(
|
|
error=self._make_http_status_error(response), **self._COMMON_KWARGS
|
|
)
|
|
assert err.status == 400
|
|
assert err.parsed_error == "invalid temperature"
|
|
assert "invalid temperature" in err.body_text
|
|
|
|
def test_sdk_error_stream_response_falls_back_to_read(self) -> None:
|
|
response = httpx.Response(
|
|
400,
|
|
stream=httpx.ByteStream(b'{"message": "context too long"}'),
|
|
request=httpx.Request("POST", "https://api.test.com"),
|
|
)
|
|
sdk_err = SDKError(
|
|
"sdk error", response, body='{"message": "context too long"}'
|
|
)
|
|
err = BackendErrorBuilder.build_http_error(error=sdk_err, **self._COMMON_KWARGS)
|
|
assert err.parsed_error == "context too long"
|
|
assert "context too long" in err.body_text
|
|
|
|
def test_http_status_error_stream_response_falls_back_to_read(self) -> None:
|
|
response = httpx.Response(
|
|
400,
|
|
stream=httpx.ByteStream(b'{"message": "context too long"}'),
|
|
request=httpx.Request("POST", "https://api.test.com"),
|
|
)
|
|
err = BackendErrorBuilder.build_http_error(
|
|
error=self._make_http_status_error(response), **self._COMMON_KWARGS
|
|
)
|
|
assert err.parsed_error == "context too long"
|
|
assert "context too long" in err.body_text
|
|
|
|
def test_sdk_error_unreadable_response_falls_back_to_str(self) -> None:
|
|
response = MagicMock(spec=httpx.Response)
|
|
response.status_code = 400
|
|
response.reason_phrase = "Bad Request"
|
|
response.headers = {}
|
|
type(response).text = property(lambda self: (_ for _ in ()).throw(Exception))
|
|
response.read.side_effect = Exception("closed")
|
|
|
|
sdk_err = SDKError("sdk msg", response, body='{"message": "context too long"}')
|
|
err = BackendErrorBuilder.build_http_error(error=sdk_err, **self._COMMON_KWARGS)
|
|
assert err.body_text == '{"message": "context too long"}'
|
|
assert err.parsed_error == "context too long"
|
|
|
|
def test_http_status_error_unreadable_response_falls_back_to_str(self) -> None:
|
|
response = MagicMock(spec=httpx.Response)
|
|
response.status_code = 400
|
|
response.reason_phrase = "Bad Request"
|
|
response.headers = {}
|
|
type(response).text = property(lambda self: (_ for _ in ()).throw(Exception))
|
|
response.read.side_effect = Exception("closed")
|
|
response.request = httpx.Request("POST", "https://api.test.com")
|
|
|
|
http_err = httpx.HTTPStatusError(
|
|
"http error with details", request=response.request, response=response
|
|
)
|
|
err = BackendErrorBuilder.build_http_error(
|
|
error=http_err, **self._COMMON_KWARGS
|
|
)
|
|
assert "http error with details" in err.body_text
|