mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-authored-by: Quentin Torroba <quentin.torroba@mistral.ai> Co-authored-by: Clément Siriex <clement.sirieix@mistral.ai> Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai> Co-authored-by: Michel Thomazo <michel.thomazo@mistral.ai> Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
185 lines
5.5 KiB
Python
185 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
|
|
|
from vibe.core.tools.base import BaseTool
|
|
from vibe.core.types import (
|
|
AvailableFunction,
|
|
AvailableTool,
|
|
LLMMessage,
|
|
Role,
|
|
StrToolChoice,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from vibe.core.tools.manager import ToolManager
|
|
|
|
|
|
class ParsedToolCall(BaseModel):
|
|
model_config = ConfigDict(frozen=True)
|
|
tool_name: str
|
|
raw_args: dict[str, Any]
|
|
call_id: str = ""
|
|
|
|
|
|
class ResolvedToolCall(BaseModel):
|
|
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
|
tool_name: str
|
|
tool_class: type[BaseTool]
|
|
validated_args: BaseModel
|
|
call_id: str = ""
|
|
|
|
@property
|
|
def args_dict(self) -> dict[str, Any]:
|
|
return self.validated_args.model_dump()
|
|
|
|
|
|
class FailedToolCall(BaseModel):
|
|
model_config = ConfigDict(frozen=True)
|
|
tool_name: str
|
|
call_id: str
|
|
error: str
|
|
|
|
|
|
class ParsedMessage(BaseModel):
|
|
model_config = ConfigDict(frozen=True)
|
|
tool_calls: list[ParsedToolCall]
|
|
|
|
|
|
class ResolvedMessage(BaseModel):
|
|
model_config = ConfigDict(frozen=True)
|
|
tool_calls: list[ResolvedToolCall]
|
|
failed_calls: list[FailedToolCall] = Field(default_factory=list)
|
|
|
|
|
|
class APIToolFormatHandler:
|
|
@property
|
|
def name(self) -> str:
|
|
return "api"
|
|
|
|
def get_available_tools(self, tool_manager: ToolManager) -> list[AvailableTool]:
|
|
return [
|
|
AvailableTool(
|
|
function=AvailableFunction(
|
|
name=tool_class.get_name(),
|
|
description=tool_class.description,
|
|
parameters=tool_class.get_parameters(),
|
|
)
|
|
)
|
|
for tool_class in tool_manager.available_tools.values()
|
|
]
|
|
|
|
def get_tool_choice(self) -> StrToolChoice | AvailableTool:
|
|
return "auto"
|
|
|
|
def process_api_response_message(self, message: Any) -> LLMMessage:
|
|
clean_message = {
|
|
"role": message.role,
|
|
"content": message.content,
|
|
"reasoning_content": getattr(message, "reasoning_content", None),
|
|
"reasoning_signature": getattr(message, "reasoning_signature", None),
|
|
}
|
|
|
|
if message.tool_calls:
|
|
clean_message["tool_calls"] = [
|
|
{
|
|
"id": tc.id,
|
|
"index": tc.index,
|
|
"type": "function",
|
|
"function": {
|
|
"name": tc.function.name,
|
|
"arguments": tc.function.arguments,
|
|
},
|
|
}
|
|
for tc in message.tool_calls
|
|
]
|
|
|
|
return LLMMessage.model_validate(clean_message)
|
|
|
|
def parse_message(self, message: LLMMessage) -> ParsedMessage:
|
|
tool_calls = []
|
|
|
|
api_tool_calls = message.tool_calls or []
|
|
for tc in api_tool_calls:
|
|
if not (function_call := tc.function):
|
|
continue
|
|
try:
|
|
args = json.loads(function_call.arguments or "{}")
|
|
except json.JSONDecodeError:
|
|
args = {}
|
|
|
|
tool_calls.append(
|
|
ParsedToolCall(
|
|
tool_name=function_call.name or "",
|
|
raw_args=args,
|
|
call_id=tc.id or "",
|
|
)
|
|
)
|
|
|
|
return ParsedMessage(tool_calls=tool_calls)
|
|
|
|
def resolve_tool_calls(
|
|
self, parsed: ParsedMessage, tool_manager: ToolManager
|
|
) -> ResolvedMessage:
|
|
resolved_calls = []
|
|
failed_calls = []
|
|
|
|
active_tools = tool_manager.available_tools
|
|
|
|
for parsed_call in parsed.tool_calls:
|
|
tool_class = active_tools.get(parsed_call.tool_name)
|
|
if not tool_class:
|
|
failed_calls.append(
|
|
FailedToolCall(
|
|
tool_name=parsed_call.tool_name,
|
|
call_id=parsed_call.call_id,
|
|
error=f"Unknown tool '{parsed_call.tool_name}'",
|
|
)
|
|
)
|
|
continue
|
|
|
|
args_model, _ = tool_class._get_tool_args_results()
|
|
try:
|
|
validated_args = args_model.model_validate(parsed_call.raw_args)
|
|
resolved_calls.append(
|
|
ResolvedToolCall(
|
|
tool_name=parsed_call.tool_name,
|
|
tool_class=tool_class,
|
|
validated_args=validated_args,
|
|
call_id=parsed_call.call_id,
|
|
)
|
|
)
|
|
except ValidationError as e:
|
|
failed_calls.append(
|
|
FailedToolCall(
|
|
tool_name=parsed_call.tool_name,
|
|
call_id=parsed_call.call_id,
|
|
error=f"Invalid arguments: {e}",
|
|
)
|
|
)
|
|
|
|
return ResolvedMessage(tool_calls=resolved_calls, failed_calls=failed_calls)
|
|
|
|
def create_tool_response_message(
|
|
self, tool_call: ResolvedToolCall, result_text: str
|
|
) -> LLMMessage:
|
|
return LLMMessage(
|
|
role=Role.tool,
|
|
tool_call_id=tool_call.call_id,
|
|
name=tool_call.tool_name,
|
|
content=result_text,
|
|
)
|
|
|
|
def create_failed_tool_response_message(
|
|
self, failed: FailedToolCall, error_content: str
|
|
) -> LLMMessage:
|
|
return LLMMessage(
|
|
role=Role.tool,
|
|
tool_call_id=failed.call_id,
|
|
name=failed.tool_name,
|
|
content=error_content,
|
|
)
|