mirror of https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-authored-by: Quentin Torroba <quentin.torroba@mistral.ai>
Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai>
Co-authored-by: Mistral Vibe <vibe@mistral.ai>
429 lines
14 KiB
Python
from __future__ import annotations

from collections.abc import AsyncGenerator, Sequence
import json
import os
import types
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple

import httpx

from vibe.core.llm.backend.anthropic import AnthropicAdapter
from vibe.core.llm.backend.base import APIAdapter, PreparedRequest
from vibe.core.llm.backend.reasoning_adapter import ReasoningAdapter
from vibe.core.llm.backend.vertex import VertexAnthropicAdapter
from vibe.core.llm.exceptions import BackendErrorBuilder
from vibe.core.llm.message_utils import merge_consecutive_user_messages
from vibe.core.types import (
    AvailableTool,
    LLMChunk,
    LLMMessage,
    LLMUsage,
    Role,
    StrToolChoice,
)
from vibe.core.utils import async_generator_retry, async_retry

if TYPE_CHECKING:
    from vibe.core.config import ModelConfig, ProviderConfig

class OpenAIAdapter(APIAdapter):
    endpoint: ClassVar[str] = "/chat/completions"

    def build_payload(
        self,
        model_name: str,
        converted_messages: list[dict[str, Any]],
        temperature: float,
        tools: list[AvailableTool] | None,
        max_tokens: int | None,
        tool_choice: StrToolChoice | AvailableTool | None,
    ) -> dict[str, Any]:
        payload = {
            "model": model_name,
            "messages": converted_messages,
            "temperature": temperature,
        }

        if tools:
            payload["tools"] = [tool.model_dump(exclude_none=True) for tool in tools]
        if tool_choice:
            payload["tool_choice"] = (
                tool_choice
                if isinstance(tool_choice, str)
                else tool_choice.model_dump()
            )
        if max_tokens is not None:
            payload["max_tokens"] = max_tokens

        return payload
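
    # Illustrative example (not in the original source): a call such as
    #   build_payload("example-model", [{"role": "user", "content": "Hi"}],
    #                 0.2, None, 256, None)
    # would produce
    #   {"model": "example-model",
    #    "messages": [{"role": "user", "content": "Hi"}],
    #    "temperature": 0.2,
    #    "max_tokens": 256}
    # The model name and token limit are invented values for this sketch.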

    def build_headers(self, api_key: str | None = None) -> dict[str, str]:
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _reasoning_to_api(
        self, msg_dict: dict[str, Any], field_name: str
    ) -> dict[str, Any]:
        if field_name != "reasoning_content" and "reasoning_content" in msg_dict:
            msg_dict[field_name] = msg_dict.pop("reasoning_content")
        return msg_dict

    def _reasoning_from_api(
        self, msg_dict: dict[str, Any], field_name: str
    ) -> dict[str, Any]:
        if field_name != "reasoning_content" and field_name in msg_dict:
            msg_dict["reasoning_content"] = msg_dict.pop(field_name)
        return msg_dict
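
    # Illustrative example (not in the original source): for a provider whose
    # reasoning_field_name is, say, "reasoning", an outgoing message
    #   {"role": "assistant", "content": "...", "reasoning_content": "..."}
    # is sent to the API as
    #   {"role": "assistant", "content": "...", "reasoning": "..."}
    # and the rename is reversed for messages coming back from the API.
    # The field name "reasoning" is only an assumed example value.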

    def prepare_request(  # noqa: PLR0913
        self,
        *,
        model_name: str,
        messages: Sequence[LLMMessage],
        temperature: float,
        tools: list[AvailableTool] | None,
        max_tokens: int | None,
        tool_choice: StrToolChoice | AvailableTool | None,
        enable_streaming: bool,
        provider: ProviderConfig,
        api_key: str | None = None,
        thinking: str = "off",
    ) -> PreparedRequest:
        merged_messages = merge_consecutive_user_messages(messages)
        field_name = provider.reasoning_field_name
        converted_messages = [
            self._reasoning_to_api(
                msg.model_dump(exclude_none=True, exclude={"message_id"}), field_name
            )
            for msg in merged_messages
        ]

        payload = self.build_payload(
            model_name, converted_messages, temperature, tools, max_tokens, tool_choice
        )

        if enable_streaming:
            payload["stream"] = True
            stream_options = {"include_usage": True}
            if provider.name == "mistral":
                stream_options["stream_tool_calls"] = True
            payload["stream_options"] = stream_options

        headers = self.build_headers(api_key)
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")

        return PreparedRequest(self.endpoint, headers, body)

    def _parse_message(
        self, data: dict[str, Any], field_name: str
    ) -> LLMMessage | None:
        if data.get("choices"):
            choice = data["choices"][0]
            if "message" in choice:
                msg_dict = self._reasoning_from_api(choice["message"], field_name)
                return LLMMessage.model_validate(msg_dict)
            if "delta" in choice:
                msg_dict = self._reasoning_from_api(choice["delta"], field_name)
                return LLMMessage.model_validate(msg_dict)
            raise ValueError("Invalid response data: missing message or delta")

        if "message" in data:
            msg_dict = self._reasoning_from_api(data["message"], field_name)
            return LLMMessage.model_validate(msg_dict)
        if "delta" in data:
            msg_dict = self._reasoning_from_api(data["delta"], field_name)
            return LLMMessage.model_validate(msg_dict)

        return None

    def parse_response(
        self, data: dict[str, Any], provider: ProviderConfig
    ) -> LLMChunk:
        message = self._parse_message(data, provider.reasoning_field_name)
        if message is None:
            message = LLMMessage(role=Role.assistant, content="")

        usage_data = data.get("usage") or {}
        usage = LLMUsage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
        )

        return LLMChunk(message=message, usage=usage)
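
    # Illustrative example (not in the original source): a non-streaming
    # OpenAI-style response such as
    #   {"choices": [{"message": {"role": "assistant", "content": "Hello"}}],
    #    "usage": {"prompt_tokens": 12, "completion_tokens": 3}}
    # becomes an LLMChunk whose message has role "assistant" and content
    # "Hello", with usage of 12 prompt tokens and 3 completion tokens.
    # The token counts are made up for the example.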


ADAPTERS: dict[str, APIAdapter] = {
    "openai": OpenAIAdapter(),
    "anthropic": AnthropicAdapter(),
    "vertex-anthropic": VertexAnthropicAdapter(),
    "reasoning": ReasoningAdapter(),
}


class GenericBackend:
    def __init__(
        self,
        *,
        client: httpx.AsyncClient | None = None,
        provider: ProviderConfig,
        timeout: float = 720.0,
    ) -> None:
        """Initialize the backend.

        Args:
            client: Optional httpx client to use. If not provided, one will be created.
            provider: Provider configuration used to build and send requests.
            timeout: Request timeout in seconds for a client created by the backend.
        """
        self._client = client
        self._owns_client = client is None
        self._provider = provider
        self._timeout = timeout

    async def __aenter__(self) -> GenericBackend:
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._timeout),
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
            )
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        if self._owns_client and self._client:
            await self._client.aclose()
            self._client = None

    def _get_client(self) -> httpx.AsyncClient:
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self._timeout),
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
            )
            self._owns_client = True
        return self._client

    async def complete(
        self,
        *,
        model: ModelConfig,
        messages: Sequence[LLMMessage],
        temperature: float = 0.2,
        tools: list[AvailableTool] | None = None,
        max_tokens: int | None = None,
        tool_choice: StrToolChoice | AvailableTool | None = None,
        extra_headers: dict[str, str] | None = None,
        metadata: dict[str, str] | None = None,
    ) -> LLMChunk:
        api_key = (
            os.getenv(self._provider.api_key_env_var)
            if self._provider.api_key_env_var
            else None
        )

        api_style = getattr(self._provider, "api_style", "openai")
        adapter = ADAPTERS[api_style]

        req = adapter.prepare_request(
            model_name=model.name,
            messages=messages,
            temperature=temperature,
            tools=tools,
            max_tokens=max_tokens,
            tool_choice=tool_choice,
            enable_streaming=False,
            provider=self._provider,
            api_key=api_key,
            thinking=model.thinking,
        )

        headers = req.headers
        if extra_headers:
            headers.update(extra_headers)

        base = req.base_url or self._provider.api_base
        url = f"{base}{req.endpoint}"

        try:
            res_data, _ = await self._make_request(url, req.body, headers)
            return adapter.parse_response(res_data, self._provider)

        except httpx.HTTPStatusError as e:
            raise BackendErrorBuilder.build_http_error(
                provider=self._provider.name,
                endpoint=url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
        except httpx.RequestError as e:
            raise BackendErrorBuilder.build_request_error(
                provider=self._provider.name,
                endpoint=url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e

    async def complete_streaming(
        self,
        *,
        model: ModelConfig,
        messages: Sequence[LLMMessage],
        temperature: float = 0.2,
        tools: list[AvailableTool] | None = None,
        max_tokens: int | None = None,
        tool_choice: StrToolChoice | AvailableTool | None = None,
        extra_headers: dict[str, str] | None = None,
        metadata: dict[str, str] | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        api_key = (
            os.getenv(self._provider.api_key_env_var)
            if self._provider.api_key_env_var
            else None
        )

        api_style = getattr(self._provider, "api_style", "openai")
        adapter = ADAPTERS[api_style]

        req = adapter.prepare_request(
            model_name=model.name,
            messages=messages,
            temperature=temperature,
            tools=tools,
            max_tokens=max_tokens,
            tool_choice=tool_choice,
            enable_streaming=True,
            provider=self._provider,
            api_key=api_key,
            thinking=model.thinking,
        )

        headers = req.headers
        if extra_headers:
            headers.update(extra_headers)

        base = req.base_url or self._provider.api_base
        url = f"{base}{req.endpoint}"

        try:
            async for res_data in self._make_streaming_request(url, req.body, headers):
                yield adapter.parse_response(res_data, self._provider)

        except httpx.HTTPStatusError as e:
            raise BackendErrorBuilder.build_http_error(
                provider=self._provider.name,
                endpoint=url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e
        except httpx.RequestError as e:
            raise BackendErrorBuilder.build_request_error(
                provider=self._provider.name,
                endpoint=url,
                error=e,
                model=model.name,
                messages=messages,
                temperature=temperature,
                has_tools=bool(tools),
                tool_choice=tool_choice,
            ) from e

    class HTTPResponse(NamedTuple):
        data: dict[str, Any]
        headers: dict[str, str]

    @async_retry(tries=3)
    async def _make_request(
        self, url: str, data: bytes, headers: dict[str, str]
    ) -> HTTPResponse:
        client = self._get_client()
        response = await client.post(url, content=data, headers=headers)
        response.raise_for_status()

        response_headers = dict(response.headers.items())
        response_body = response.json()
        return self.HTTPResponse(response_body, response_headers)

    @async_generator_retry(tries=3)
    async def _make_streaming_request(
        self, url: str, data: bytes, headers: dict[str, str]
    ) -> AsyncGenerator[dict[str, Any]]:
        client = self._get_client()
        async with client.stream(
            method="POST", url=url, content=data, headers=headers
        ) as response:
            if not response.is_success:
                await response.aread()
                response.raise_for_status()
            async for line in response.aiter_lines():
                if line.strip() == "":
                    continue

                DELIM_CHAR = ":"
                if f"{DELIM_CHAR} " not in line:
                    raise ValueError(
                        f"Stream chunk improperly formatted. "
                        f"Expected `key{DELIM_CHAR} value`, received `{line}`"
                    )
                delim_index = line.find(DELIM_CHAR)
                key = line[0:delim_index]
                value = line[delim_index + 2 :]

                if key != "data":
                    # This might be the case with openrouter, so we just ignore it
                    continue
                if value == "[DONE]":
                    return
                yield json.loads(value.strip())
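
    # Illustrative example (not in the original source): server-sent-event
    # style lines from an OpenAI-compatible stream typically look like
    #   data: {"choices": [{"delta": {"content": "Hel"}}]}
    #   data: {"choices": [{"delta": {"content": "lo"}}]}
    #   data: [DONE]
    # Each `data:` payload is JSON-decoded and yielded; `[DONE]` ends the
    # stream, and lines with other keys are skipped.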

    async def count_tokens(
        self,
        *,
        model: ModelConfig,
        messages: Sequence[LLMMessage],
        temperature: float = 0.0,
        tools: list[AvailableTool] | None = None,
        tool_choice: StrToolChoice | AvailableTool | None = None,
        extra_headers: dict[str, str] | None = None,
        metadata: dict[str, str] | None = None,
    ) -> int:
        probe_messages = list(messages)
        if not probe_messages or probe_messages[-1].role != Role.user:
            probe_messages.append(LLMMessage(role=Role.user, content=""))

        result = await self.complete(
            model=model,
            messages=probe_messages,
            temperature=temperature,
            tools=tools,
            max_tokens=16,  # Minimal amount for openrouter with openai models
            tool_choice=tool_choice,
            extra_headers=extra_headers,
        )
        if result.usage is None:
            raise ValueError("Missing usage in non streaming completion")

        return result.usage.prompt_tokens

    async def close(self) -> None:
        if self._owns_client and self._client:
            await self._client.aclose()
            self._client = None
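

# The sketch below is illustrative and not part of the original module. It
# shows one way GenericBackend could be driven end to end, assuming a
# ProviderConfig and ModelConfig constructed elsewhere; the field values and
# constructor arguments shown here are invented for the example, which is why
# it is left as a comment rather than executable code.
#
#     async def example() -> None:
#         from vibe.core.config import ModelConfig, ProviderConfig
#
#         provider = ProviderConfig(
#             name="mistral",
#             api_base="https://api.mistral.ai/v1",
#             api_key_env_var="MISTRAL_API_KEY",
#         )
#         model = ModelConfig(name="example-model")
#         async with GenericBackend(provider=provider) as backend:
#             chunk = await backend.complete(
#                 model=model,
#                 messages=[LLMMessage(role=Role.user, content="Hello!")],
#             )
#             print(chunk.message.content)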