Files
mistral-vibe/vibe/core/llm/backend/mistral.py
Mathias Gesbert e9a9217cc8 v2.7.4 (#579)
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai>
Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai>
Co-authored-by: Lucas Marandat <31749711+lucasmrdt@users.noreply.github.com>
Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com>
Co-authored-by: Paul Cacheux <paul.cacheux@mistral.ai>
Co-authored-by: Peter Evers <pevers90@gmail.com>
Co-authored-by: Pierre Rossinès <pierre.rossines@mistral.ai>
Co-authored-by: Pierre Rossinès <pierre.rossines@protonmail.com>
Co-authored-by: Quentin <quentin.torroba@mistral.ai>
Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai>
Co-authored-by: Val <102326092+vdeva@users.noreply.github.com>
Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com>
Co-authored-by: Mistral Vibe <vibe@mistral.ai>
2026-04-09 18:40:46 +02:00

441 lines
15 KiB
Python

from __future__ import annotations
from collections.abc import AsyncGenerator, Sequence
import json
import os
import types
from typing import TYPE_CHECKING, Literal, NamedTuple, cast
import httpx
from mistralai.client import Mistral
from mistralai.client.errors import SDKError
from mistralai.client.models import (
AssistantMessage,
AssistantMessageContent,
ChatCompletionRequestMessage,
ChatCompletionStreamRequestToolChoice,
ContentChunk,
FileChunk,
Function,
FunctionCall as MistralFunctionCall,
FunctionName,
SystemMessage,
TextChunk,
ThinkChunk,
Tool,
ToolCall as MistralToolCall,
ToolChoice,
ToolChoiceEnum,
ToolMessage,
UserMessage,
)
from mistralai.client.utils.retries import BackoffStrategy, RetryConfig
from vibe.core.llm.exceptions import BackendErrorBuilder
from vibe.core.llm.message_utils import merge_consecutive_user_messages
from vibe.core.types import (
AvailableTool,
Content,
FunctionCall,
LLMChunk,
LLMMessage,
LLMUsage,
Role,
StrToolChoice,
ToolCall,
)
from vibe.core.utils import get_server_url_from_api_base
if TYPE_CHECKING:
from vibe.core.config import ModelConfig, ProviderConfig
class ParsedContent(NamedTuple):
    """Assistant message content split into visible text and reasoning text."""

    # Concatenated visible (TextChunk) text; may be empty.
    content: Content
    # Concatenated ThinkChunk text, or None when the message carried no reasoning.
    reasoning_content: Content | None
class MistralMapper:
    """Translates vibe's provider-agnostic LLM types to and from the Mistral SDK models."""

    def prepare_message(self, msg: LLMMessage) -> ChatCompletionRequestMessage:
        """Build the SDK request message corresponding to ``msg.role``."""
        if msg.role == Role.system:
            return SystemMessage(role="system", content=msg.content or "")
        if msg.role == Role.user:
            return UserMessage(role="user", content=msg.content)
        if msg.role == Role.assistant:
            return AssistantMessage(
                role="assistant",
                content=self._assistant_content(msg),
                tool_calls=[
                    self._to_sdk_tool_call(tc) for tc in msg.tool_calls or []
                ],
            )
        if msg.role == Role.tool:
            return ToolMessage(
                role="tool",
                content=msg.content,
                tool_call_id=msg.tool_call_id,
                name=msg.name,
            )

    def _assistant_content(self, msg: LLMMessage) -> AssistantMessageContent:
        """Assemble assistant content, wrapping any reasoning text in a ThinkChunk."""
        if not msg.reasoning_content:
            return msg.content or ""
        parts: list[ContentChunk] = [
            ThinkChunk(
                type="thinking",
                thinking=[TextChunk(type="text", text=msg.reasoning_content)],
            )
        ]
        if msg.content:
            parts.append(TextChunk(type="text", text=msg.content))
        return parts

    def _to_sdk_tool_call(self, tc: ToolCall) -> MistralToolCall:
        """Convert one vibe tool call into the SDK's tool-call model."""
        return MistralToolCall(
            function=MistralFunctionCall(
                name=tc.function.name or "",
                arguments=tc.function.arguments or "",
            ),
            id=tc.id,
            type=tc.type,
            index=tc.index,
        )

    def prepare_tool(self, tool: AvailableTool) -> Tool:
        """Expose an AvailableTool as an SDK function tool definition."""
        sdk_function = Function(
            name=tool.function.name,
            description=tool.function.description,
            parameters=tool.function.parameters,
        )
        return Tool(type="function", function=sdk_function)

    def prepare_tool_choice(
        self, tool_choice: StrToolChoice | AvailableTool
    ) -> ChatCompletionStreamRequestToolChoice:
        """Pass string choices through as-is; otherwise name the specific function."""
        if not isinstance(tool_choice, str):
            return ToolChoice(
                type="function",
                function=FunctionName(name=tool_choice.function.name),
            )
        return cast(ToolChoiceEnum, tool_choice)

    def _extract_thinking_text(self, chunk: ThinkChunk) -> str:
        """Flatten a ThinkChunk's inner items (text chunks or raw strings) into one string."""
        inner_items = getattr(chunk, "thinking", None)
        if not inner_items:
            return ""
        collected: list[str] = []
        for item in inner_items:
            if hasattr(item, "type") and item.type == "text":
                collected.append(getattr(item, "text", ""))
            elif isinstance(item, str):
                collected.append(item)
        return "".join(collected)

    def parse_content(self, content: AssistantMessageContent) -> ParsedContent:
        """Split SDK assistant content into visible text and reasoning text."""
        if isinstance(content, str):
            return ParsedContent(content=content, reasoning_content=None)
        text_parts: list[str] = []
        reasoning_parts: list[str] = []
        for piece in content:
            if isinstance(piece, FileChunk):
                # File attachments carry no text to surface.
                continue
            if isinstance(piece, TextChunk):
                text_parts.append(piece.text)
            elif isinstance(piece, ThinkChunk):
                reasoning_parts.append(self._extract_thinking_text(piece))
        reasoning = "".join(reasoning_parts)
        return ParsedContent(
            content="".join(text_parts),
            reasoning_content=reasoning if reasoning else None,
        )

    def parse_tool_calls(self, tool_calls: list[MistralToolCall]) -> list[ToolCall]:
        """Convert SDK tool calls into vibe ToolCall objects with string arguments."""
        converted: list[ToolCall] = []
        for sdk_call in tool_calls:
            raw_args = sdk_call.function.arguments
            # Arguments may arrive as a mapping; serialize to JSON in that case.
            arguments = (
                raw_args
                if isinstance(raw_args, str)
                else json.dumps(raw_args, ensure_ascii=False)
            )
            converted.append(
                ToolCall(
                    id=sdk_call.id,
                    function=FunctionCall(
                        name=sdk_call.function.name, arguments=arguments
                    ),
                    index=sdk_call.index,
                )
            )
        return converted
# Values this backend is willing to send as the Mistral `reasoning_effort`.
ReasoningEffortValue = Literal["none", "high"]
# Maps vibe's `model.thinking` level onto a `reasoning_effort` value:
# "low" maps to "none", while "medium" and "high" both map to "high".
# Levels absent from this table leave reasoning_effort unset (None).
_THINKING_TO_REASONING_EFFORT: dict[str, ReasoningEffortValue] = {
    "low": "none",
    "medium": "high",
    "high": "high",
}
class MistralBackend:
    """Chat-completion backend built on the official `mistralai` async SDK.

    Exposes a non-streaming `complete`, a streaming `complete_streaming`, and
    a `count_tokens` helper that reads prompt-token usage from a minimal
    completion. SDK and transport errors are re-raised as vibe backend errors
    via `BackendErrorBuilder`. The instance works as an async context manager;
    used outside `async with`, a client is created lazily on first request.
    """
    def __init__(self, provider: ProviderConfig, timeout: float = 720.0) -> None:
        """Validate provider settings and precompute client configuration.

        Args:
            provider: Provider configuration (API base URL, API-key env var, ...).
            timeout: Per-request timeout in seconds (passed to the SDK in ms).

        Raises:
            ValueError: If the provider requests a custom reasoning field name,
                or if the API base URL lacks a `/v<api_version>` suffix.
        """
        # Created lazily (or in __aenter__); stays None until first use.
        self._client: Mistral | None = None
        self._provider = provider
        self._mapper = MistralMapper()
        # API key is read from the environment variable named in the provider
        # config; None when no variable is configured.
        self._api_key = (
            os.getenv(self._provider.api_key_env_var)
            if self._provider.api_key_env_var
            else None
        )
        # Reasoning arrives as ThinkChunk content in this backend, so a custom
        # reasoning field name cannot be honored.
        reasoning_field = getattr(provider, "reasoning_field_name", "reasoning_content")
        if reasoning_field != "reasoning_content":
            raise ValueError(
                f"Mistral backend does not support custom reasoning_field_name "
                f"(got '{reasoning_field}'). Mistral uses ThinkChunk for reasoning."
            )
        # Mistral SDK takes server URL without api version as input
        server_url = get_server_url_from_api_base(self._provider.api_base)
        if not server_url:
            raise ValueError(
                f"Invalid API base URL: {self._provider.api_base}. "
                "Expected format: <server_url>/v<api_version>"
            )
        self._server_url = server_url
        self._timeout = timeout
        self._retry_config = self._build_retry_config()
    def _build_retry_config(self) -> RetryConfig:
        """Build the exponential-backoff retry policy for the SDK client.

        NOTE(review): interval values are presumably milliseconds, matching
        the SDK's `timeout_ms` convention — confirm against the SDK docs.
        """
        return RetryConfig(
            strategy="backoff",
            backoff=BackoffStrategy(
                initial_interval=500,
                max_interval=30000,
                exponent=1.5,
                max_elapsed_time=300000,
            ),
            retry_connection_errors=True,
        )
    async def __aenter__(self) -> MistralBackend:
        """Create the SDK client and enter its async context."""
        self._client = self._create_mistral_client()
        await self._client.__aenter__()
        return self
    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Exit the SDK client's async context, if a client was created."""
        if self._client is not None:
            await self._client.__aexit__(
                exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb
            )
    def _create_mistral_client(self) -> Mistral:
        """Build a configured `Mistral` SDK client (does not open it)."""
        return Mistral(
            api_key=self._api_key,
            server_url=self._server_url,
            timeout_ms=int(self._timeout * 1000),
            retry_config=self._retry_config,
        )
    def _get_client(self) -> Mistral:
        """Return the current client, creating one lazily if needed."""
        if self._client is None:
            self._client = self._create_mistral_client()
        return self._client
    async def complete(
        self,
        *,
        model: ModelConfig,
        messages: Sequence[LLMMessage],
        temperature: float,
        tools: list[AvailableTool] | None,
        max_tokens: int | None,
        tool_choice: StrToolChoice | AvailableTool | None,
        extra_headers: dict[str, str] | None,
        metadata: dict[str, str] | None = None,
    ) -> LLMChunk:
        """Run one non-streaming chat completion and return the full result.

        Consecutive user messages are merged before the request. When the
        model's thinking level maps to a reasoning effort, the temperature is
        forced to 1.0. SDK (`SDKError`) and transport (`httpx.RequestError`)
        failures are re-raised as backend errors built by `BackendErrorBuilder`.
        """
        try:
            merged_messages = merge_consecutive_user_messages(messages)
            # None when model.thinking has no mapping; the request is then
            # sent without forcing reasoning.
            reasoning_effort = _THINKING_TO_REASONING_EFFORT.get(model.thinking)
            if reasoning_effort is not None:
                # Reasoning-enabled requests are sent at temperature 1.0.
                temperature = 1.0
            response = await self._get_client().chat.complete_async(
                model=model.name,
                messages=[self._mapper.prepare_message(msg) for msg in merged_messages],
                temperature=temperature,
                tools=[self._mapper.prepare_tool(tool) for tool in tools]
                if tools
                else None,
                max_tokens=max_tokens,
                tool_choice=self._mapper.prepare_tool_choice(tool_choice)
                if tool_choice
                else None,
                http_headers=extra_headers,
                metadata=metadata,
                stream=False,
                reasoning_effort=reasoning_effort,
            )
            # Split assistant content into visible text and reasoning text;
            # an absent content field yields an empty parsed result.
            parsed = (
                self._mapper.parse_content(response.choices[0].message.content)
                if response.choices[0].message.content
                else ParsedContent(content="", reasoning_content=None)
            )
            return LLMChunk(
                message=LLMMessage(
                    role=Role.assistant,
                    content=parsed.content,
                    reasoning_content=parsed.reasoning_content,
                    tool_calls=self._mapper.parse_tool_calls(
                        response.choices[0].message.tool_calls
                    )
                    if response.choices[0].message.tool_calls
                    else None,
                ),
                usage=LLMUsage(
                    prompt_tokens=response.usage.prompt_tokens or 0,
                    completion_tokens=response.usage.completion_tokens or 0,
                ),
            )
        except SDKError as e:
            # HTTP-level failure reported by the SDK.
            raise BackendErrorBuilder.build_http_error(
                provider=self._provider.name,
                endpoint=self._server_url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
        except httpx.RequestError as e:
            # Transport-level failure (connection, timeout, ...).
            raise BackendErrorBuilder.build_request_error(
                provider=self._provider.name,
                endpoint=self._server_url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
    async def complete_streaming(
        self,
        *,
        model: ModelConfig,
        messages: Sequence[LLMMessage],
        temperature: float,
        tools: list[AvailableTool] | None,
        max_tokens: int | None,
        tool_choice: StrToolChoice | AvailableTool | None,
        extra_headers: dict[str, str] | None,
        metadata: dict[str, str] | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a chat completion, yielding one `LLMChunk` per SDK event.

        Same preprocessing as `complete` (user-message merging, reasoning
        effort, forced temperature). Errors raised while opening or consuming
        the stream are wrapped via `BackendErrorBuilder`.
        """
        try:
            merged_messages = merge_consecutive_user_messages(messages)
            reasoning_effort = _THINKING_TO_REASONING_EFFORT.get(model.thinking)
            if reasoning_effort is not None:
                # Reasoning-enabled requests are sent at temperature 1.0.
                temperature = 1.0
            stream = await self._get_client().chat.stream_async(
                model=model.name,
                messages=[self._mapper.prepare_message(msg) for msg in merged_messages],
                temperature=temperature,
                tools=[self._mapper.prepare_tool(tool) for tool in tools]
                if tools
                else None,
                max_tokens=max_tokens,
                tool_choice=self._mapper.prepare_tool_choice(tool_choice)
                if tool_choice
                else None,
                http_headers=extra_headers,
                metadata=metadata,
                reasoning_effort=reasoning_effort,
            )
            # The correlation id comes from the response headers and is
            # attached to every chunk yielded below.
            correlation_id = stream.response.headers.get("mistral-correlation-id")
            async for chunk in stream:
                parsed = (
                    self._mapper.parse_content(chunk.data.choices[0].delta.content)
                    if chunk.data.choices[0].delta.content
                    else ParsedContent(content="", reasoning_content=None)
                )
                yield LLMChunk(
                    message=LLMMessage(
                        role=Role.assistant,
                        content=parsed.content,
                        reasoning_content=parsed.reasoning_content,
                        tool_calls=self._mapper.parse_tool_calls(
                            chunk.data.choices[0].delta.tool_calls
                        )
                        if chunk.data.choices[0].delta.tool_calls
                        else None,
                    ),
                    # Usage may be absent on intermediate chunks; report zeros then.
                    usage=LLMUsage(
                        prompt_tokens=chunk.data.usage.prompt_tokens or 0
                        if chunk.data.usage
                        else 0,
                        completion_tokens=chunk.data.usage.completion_tokens or 0
                        if chunk.data.usage
                        else 0,
                    ),
                    correlation_id=correlation_id,
                )
        except SDKError as e:
            # HTTP-level failure reported by the SDK.
            raise BackendErrorBuilder.build_http_error(
                provider=self._provider.name,
                endpoint=self._server_url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
        except httpx.RequestError as e:
            # Transport-level failure (connection, timeout, ...).
            raise BackendErrorBuilder.build_request_error(
                provider=self._provider.name,
                endpoint=self._server_url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
    async def count_tokens(
        self,
        *,
        model: ModelConfig,
        messages: Sequence[LLMMessage],
        temperature: float = 0.0,
        tools: list[AvailableTool] | None = None,
        tool_choice: StrToolChoice | AvailableTool | None = None,
        extra_headers: dict[str, str] | None = None,
        metadata: dict[str, str] | None = None,
    ) -> int:
        """Count prompt tokens by issuing a 1-token completion and reading usage.

        Raises:
            ValueError: If the completion response carried no usage info.
        """
        result = await self.complete(
            model=model,
            messages=messages,
            temperature=temperature,
            tools=tools,
            max_tokens=1,
            tool_choice=tool_choice,
            extra_headers=extra_headers,
            metadata=metadata,
        )
        if result.usage is None:
            raise ValueError("Missing usage in non streaming completion")
        return result.usage.prompt_tokens