mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai> Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai> Co-authored-by: Lucas Marandat <31749711+lucasmrdt@users.noreply.github.com> Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com> Co-authored-by: Paul Cacheux <paul.cacheux@mistral.ai> Co-authored-by: Peter Evers <pevers90@gmail.com> Co-authored-by: Pierre Rossinès <pierre.rossines@mistral.ai> Co-authored-by: Pierre Rossinès <pierre.rossines@protonmail.com> Co-authored-by: Quentin <quentin.torroba@mistral.ai> Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai> Co-authored-by: Val <102326092+vdeva@users.noreply.github.com> Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com> Co-authored-by: Mistral Vibe <vibe@mistral.ai>
441 lines
15 KiB
Python
441 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator, Sequence
|
|
import json
|
|
import os
|
|
import types
|
|
from typing import TYPE_CHECKING, Literal, NamedTuple, cast
|
|
|
|
import httpx
|
|
from mistralai.client import Mistral
|
|
from mistralai.client.errors import SDKError
|
|
from mistralai.client.models import (
|
|
AssistantMessage,
|
|
AssistantMessageContent,
|
|
ChatCompletionRequestMessage,
|
|
ChatCompletionStreamRequestToolChoice,
|
|
ContentChunk,
|
|
FileChunk,
|
|
Function,
|
|
FunctionCall as MistralFunctionCall,
|
|
FunctionName,
|
|
SystemMessage,
|
|
TextChunk,
|
|
ThinkChunk,
|
|
Tool,
|
|
ToolCall as MistralToolCall,
|
|
ToolChoice,
|
|
ToolChoiceEnum,
|
|
ToolMessage,
|
|
UserMessage,
|
|
)
|
|
from mistralai.client.utils.retries import BackoffStrategy, RetryConfig
|
|
|
|
from vibe.core.llm.exceptions import BackendErrorBuilder
|
|
from vibe.core.llm.message_utils import merge_consecutive_user_messages
|
|
from vibe.core.types import (
|
|
AvailableTool,
|
|
Content,
|
|
FunctionCall,
|
|
LLMChunk,
|
|
LLMMessage,
|
|
LLMUsage,
|
|
Role,
|
|
StrToolChoice,
|
|
ToolCall,
|
|
)
|
|
from vibe.core.utils import get_server_url_from_api_base
|
|
|
|
if TYPE_CHECKING:
|
|
from vibe.core.config import ModelConfig, ProviderConfig
|
|
|
|
|
|
class ParsedContent(NamedTuple):
|
|
content: Content
|
|
reasoning_content: Content | None
|
|
|
|
|
|
class MistralMapper:
|
|
def prepare_message(self, msg: LLMMessage) -> ChatCompletionRequestMessage:
|
|
match msg.role:
|
|
case Role.system:
|
|
return SystemMessage(role="system", content=msg.content or "")
|
|
case Role.user:
|
|
return UserMessage(role="user", content=msg.content)
|
|
case Role.assistant:
|
|
content: AssistantMessageContent
|
|
if msg.reasoning_content:
|
|
chunks: list[ContentChunk] = [
|
|
ThinkChunk(
|
|
type="thinking",
|
|
thinking=[
|
|
TextChunk(type="text", text=msg.reasoning_content)
|
|
],
|
|
)
|
|
]
|
|
if msg.content:
|
|
chunks.append(TextChunk(type="text", text=msg.content))
|
|
content = chunks
|
|
else:
|
|
content = msg.content or ""
|
|
|
|
return AssistantMessage(
|
|
role="assistant",
|
|
content=content,
|
|
tool_calls=[
|
|
MistralToolCall(
|
|
function=MistralFunctionCall(
|
|
name=tc.function.name or "",
|
|
arguments=tc.function.arguments or "",
|
|
),
|
|
id=tc.id,
|
|
type=tc.type,
|
|
index=tc.index,
|
|
)
|
|
for tc in msg.tool_calls or []
|
|
],
|
|
)
|
|
case Role.tool:
|
|
return ToolMessage(
|
|
role="tool",
|
|
content=msg.content,
|
|
tool_call_id=msg.tool_call_id,
|
|
name=msg.name,
|
|
)
|
|
|
|
def prepare_tool(self, tool: AvailableTool) -> Tool:
|
|
return Tool(
|
|
type="function",
|
|
function=Function(
|
|
name=tool.function.name,
|
|
description=tool.function.description,
|
|
parameters=tool.function.parameters,
|
|
),
|
|
)
|
|
|
|
def prepare_tool_choice(
|
|
self, tool_choice: StrToolChoice | AvailableTool
|
|
) -> ChatCompletionStreamRequestToolChoice:
|
|
if isinstance(tool_choice, str):
|
|
return cast(ToolChoiceEnum, tool_choice)
|
|
|
|
return ToolChoice(
|
|
type="function", function=FunctionName(name=tool_choice.function.name)
|
|
)
|
|
|
|
def _extract_thinking_text(self, chunk: ThinkChunk) -> str:
|
|
thinking_content = getattr(chunk, "thinking", None)
|
|
if not thinking_content:
|
|
return ""
|
|
parts = []
|
|
for inner in thinking_content:
|
|
if hasattr(inner, "type") and inner.type == "text":
|
|
parts.append(getattr(inner, "text", ""))
|
|
elif isinstance(inner, str):
|
|
parts.append(inner)
|
|
return "".join(parts)
|
|
|
|
def parse_content(self, content: AssistantMessageContent) -> ParsedContent:
|
|
if isinstance(content, str):
|
|
return ParsedContent(content=content, reasoning_content=None)
|
|
|
|
concat_content = ""
|
|
concat_reasoning = ""
|
|
for chunk in content:
|
|
if isinstance(chunk, FileChunk):
|
|
continue
|
|
if isinstance(chunk, TextChunk):
|
|
concat_content += chunk.text
|
|
elif isinstance(chunk, ThinkChunk):
|
|
concat_reasoning += self._extract_thinking_text(chunk)
|
|
return ParsedContent(
|
|
content=concat_content,
|
|
reasoning_content=concat_reasoning if concat_reasoning else None,
|
|
)
|
|
|
|
def parse_tool_calls(self, tool_calls: list[MistralToolCall]) -> list[ToolCall]:
|
|
return [
|
|
ToolCall(
|
|
id=tool_call.id,
|
|
function=FunctionCall(
|
|
name=tool_call.function.name,
|
|
arguments=tool_call.function.arguments
|
|
if isinstance(tool_call.function.arguments, str)
|
|
else json.dumps(tool_call.function.arguments, ensure_ascii=False),
|
|
),
|
|
index=tool_call.index,
|
|
)
|
|
for tool_call in tool_calls
|
|
]
|
|
|
|
|
|
ReasoningEffortValue = Literal["none", "high"]
|
|
|
|
_THINKING_TO_REASONING_EFFORT: dict[str, ReasoningEffortValue] = {
|
|
"low": "none",
|
|
"medium": "high",
|
|
"high": "high",
|
|
}
|
|
|
|
|
|
class MistralBackend:
|
|
def __init__(self, provider: ProviderConfig, timeout: float = 720.0) -> None:
|
|
self._client: Mistral | None = None
|
|
self._provider = provider
|
|
self._mapper = MistralMapper()
|
|
self._api_key = (
|
|
os.getenv(self._provider.api_key_env_var)
|
|
if self._provider.api_key_env_var
|
|
else None
|
|
)
|
|
|
|
reasoning_field = getattr(provider, "reasoning_field_name", "reasoning_content")
|
|
if reasoning_field != "reasoning_content":
|
|
raise ValueError(
|
|
f"Mistral backend does not support custom reasoning_field_name "
|
|
f"(got '{reasoning_field}'). Mistral uses ThinkChunk for reasoning."
|
|
)
|
|
|
|
# Mistral SDK takes server URL without api version as input
|
|
server_url = get_server_url_from_api_base(self._provider.api_base)
|
|
if not server_url:
|
|
raise ValueError(
|
|
f"Invalid API base URL: {self._provider.api_base}. "
|
|
"Expected format: <server_url>/v<api_version>"
|
|
)
|
|
self._server_url = server_url
|
|
self._timeout = timeout
|
|
self._retry_config = self._build_retry_config()
|
|
|
|
def _build_retry_config(self) -> RetryConfig:
|
|
return RetryConfig(
|
|
strategy="backoff",
|
|
backoff=BackoffStrategy(
|
|
initial_interval=500,
|
|
max_interval=30000,
|
|
exponent=1.5,
|
|
max_elapsed_time=300000,
|
|
),
|
|
retry_connection_errors=True,
|
|
)
|
|
|
|
async def __aenter__(self) -> MistralBackend:
|
|
self._client = self._create_mistral_client()
|
|
await self._client.__aenter__()
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: types.TracebackType | None,
|
|
) -> None:
|
|
if self._client is not None:
|
|
await self._client.__aexit__(
|
|
exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb
|
|
)
|
|
|
|
def _create_mistral_client(self) -> Mistral:
|
|
return Mistral(
|
|
api_key=self._api_key,
|
|
server_url=self._server_url,
|
|
timeout_ms=int(self._timeout * 1000),
|
|
retry_config=self._retry_config,
|
|
)
|
|
|
|
def _get_client(self) -> Mistral:
|
|
if self._client is None:
|
|
self._client = self._create_mistral_client()
|
|
return self._client
|
|
|
|
async def complete(
|
|
self,
|
|
*,
|
|
model: ModelConfig,
|
|
messages: Sequence[LLMMessage],
|
|
temperature: float,
|
|
tools: list[AvailableTool] | None,
|
|
max_tokens: int | None,
|
|
tool_choice: StrToolChoice | AvailableTool | None,
|
|
extra_headers: dict[str, str] | None,
|
|
metadata: dict[str, str] | None = None,
|
|
) -> LLMChunk:
|
|
try:
|
|
merged_messages = merge_consecutive_user_messages(messages)
|
|
reasoning_effort = _THINKING_TO_REASONING_EFFORT.get(model.thinking)
|
|
if reasoning_effort is not None:
|
|
temperature = 1.0
|
|
|
|
response = await self._get_client().chat.complete_async(
|
|
model=model.name,
|
|
messages=[self._mapper.prepare_message(msg) for msg in merged_messages],
|
|
temperature=temperature,
|
|
tools=[self._mapper.prepare_tool(tool) for tool in tools]
|
|
if tools
|
|
else None,
|
|
max_tokens=max_tokens,
|
|
tool_choice=self._mapper.prepare_tool_choice(tool_choice)
|
|
if tool_choice
|
|
else None,
|
|
http_headers=extra_headers,
|
|
metadata=metadata,
|
|
stream=False,
|
|
reasoning_effort=reasoning_effort,
|
|
)
|
|
|
|
parsed = (
|
|
self._mapper.parse_content(response.choices[0].message.content)
|
|
if response.choices[0].message.content
|
|
else ParsedContent(content="", reasoning_content=None)
|
|
)
|
|
return LLMChunk(
|
|
message=LLMMessage(
|
|
role=Role.assistant,
|
|
content=parsed.content,
|
|
reasoning_content=parsed.reasoning_content,
|
|
tool_calls=self._mapper.parse_tool_calls(
|
|
response.choices[0].message.tool_calls
|
|
)
|
|
if response.choices[0].message.tool_calls
|
|
else None,
|
|
),
|
|
usage=LLMUsage(
|
|
prompt_tokens=response.usage.prompt_tokens or 0,
|
|
completion_tokens=response.usage.completion_tokens or 0,
|
|
),
|
|
)
|
|
|
|
except SDKError as e:
|
|
raise BackendErrorBuilder.build_http_error(
|
|
provider=self._provider.name,
|
|
endpoint=self._server_url,
|
|
error=e,
|
|
model=model.name,
|
|
messages=messages,
|
|
temperature=temperature,
|
|
has_tools=bool(tools),
|
|
tool_choice=tool_choice,
|
|
) from e
|
|
except httpx.RequestError as e:
|
|
raise BackendErrorBuilder.build_request_error(
|
|
provider=self._provider.name,
|
|
endpoint=self._server_url,
|
|
error=e,
|
|
model=model.name,
|
|
messages=messages,
|
|
temperature=temperature,
|
|
has_tools=bool(tools),
|
|
tool_choice=tool_choice,
|
|
) from e
|
|
|
|
async def complete_streaming(
|
|
self,
|
|
*,
|
|
model: ModelConfig,
|
|
messages: Sequence[LLMMessage],
|
|
temperature: float,
|
|
tools: list[AvailableTool] | None,
|
|
max_tokens: int | None,
|
|
tool_choice: StrToolChoice | AvailableTool | None,
|
|
extra_headers: dict[str, str] | None,
|
|
metadata: dict[str, str] | None = None,
|
|
) -> AsyncGenerator[LLMChunk, None]:
|
|
try:
|
|
merged_messages = merge_consecutive_user_messages(messages)
|
|
reasoning_effort = _THINKING_TO_REASONING_EFFORT.get(model.thinking)
|
|
if reasoning_effort is not None:
|
|
temperature = 1.0
|
|
|
|
stream = await self._get_client().chat.stream_async(
|
|
model=model.name,
|
|
messages=[self._mapper.prepare_message(msg) for msg in merged_messages],
|
|
temperature=temperature,
|
|
tools=[self._mapper.prepare_tool(tool) for tool in tools]
|
|
if tools
|
|
else None,
|
|
max_tokens=max_tokens,
|
|
tool_choice=self._mapper.prepare_tool_choice(tool_choice)
|
|
if tool_choice
|
|
else None,
|
|
http_headers=extra_headers,
|
|
metadata=metadata,
|
|
reasoning_effort=reasoning_effort,
|
|
)
|
|
correlation_id = stream.response.headers.get("mistral-correlation-id")
|
|
async for chunk in stream:
|
|
parsed = (
|
|
self._mapper.parse_content(chunk.data.choices[0].delta.content)
|
|
if chunk.data.choices[0].delta.content
|
|
else ParsedContent(content="", reasoning_content=None)
|
|
)
|
|
yield LLMChunk(
|
|
message=LLMMessage(
|
|
role=Role.assistant,
|
|
content=parsed.content,
|
|
reasoning_content=parsed.reasoning_content,
|
|
tool_calls=self._mapper.parse_tool_calls(
|
|
chunk.data.choices[0].delta.tool_calls
|
|
)
|
|
if chunk.data.choices[0].delta.tool_calls
|
|
else None,
|
|
),
|
|
usage=LLMUsage(
|
|
prompt_tokens=chunk.data.usage.prompt_tokens or 0
|
|
if chunk.data.usage
|
|
else 0,
|
|
completion_tokens=chunk.data.usage.completion_tokens or 0
|
|
if chunk.data.usage
|
|
else 0,
|
|
),
|
|
correlation_id=correlation_id,
|
|
)
|
|
|
|
except SDKError as e:
|
|
raise BackendErrorBuilder.build_http_error(
|
|
provider=self._provider.name,
|
|
endpoint=self._server_url,
|
|
error=e,
|
|
model=model.name,
|
|
messages=messages,
|
|
temperature=temperature,
|
|
has_tools=bool(tools),
|
|
tool_choice=tool_choice,
|
|
) from e
|
|
except httpx.RequestError as e:
|
|
raise BackendErrorBuilder.build_request_error(
|
|
provider=self._provider.name,
|
|
endpoint=self._server_url,
|
|
error=e,
|
|
model=model.name,
|
|
messages=messages,
|
|
temperature=temperature,
|
|
has_tools=bool(tools),
|
|
tool_choice=tool_choice,
|
|
) from e
|
|
|
|
async def count_tokens(
|
|
self,
|
|
*,
|
|
model: ModelConfig,
|
|
messages: Sequence[LLMMessage],
|
|
temperature: float = 0.0,
|
|
tools: list[AvailableTool] | None = None,
|
|
tool_choice: StrToolChoice | AvailableTool | None = None,
|
|
extra_headers: dict[str, str] | None = None,
|
|
metadata: dict[str, str] | None = None,
|
|
) -> int:
|
|
result = await self.complete(
|
|
model=model,
|
|
messages=messages,
|
|
temperature=temperature,
|
|
tools=tools,
|
|
max_tokens=1,
|
|
tool_choice=tool_choice,
|
|
extra_headers=extra_headers,
|
|
metadata=metadata,
|
|
)
|
|
if result.usage is None:
|
|
raise ValueError("Missing usage in non streaming completion")
|
|
|
|
return result.usage.prompt_tokens
|