Files
mistral-vibe/vibe/core/tools/base.py
Mathias Gesbert 6a50d1d521 v2.7.0 (#534)
Co-authored-by: Quentin Torroba <quentin.torroba@mistral.ai>
Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai>
Co-authored-by: Mistral Vibe <vibe@mistral.ai>
2026-03-26 10:11:34 +01:00

392 lines
13 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from enum import StrEnum, auto
import functools
import inspect
from pathlib import Path
import re
import sys
import types
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
)
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from vibe.core.logger import logger
from vibe.core.rewind.manager import FileSnapshot
from vibe.core.types import ToolStreamEvent
from vibe.core.utils.io import read_safe
if TYPE_CHECKING:
from vibe.core.agents.manager import AgentManager
from vibe.core.skills.manager import SkillManager
from vibe.core.tools.mcp_sampling import MCPSamplingHandler
from vibe.core.tools.permissions import PermissionContext
from vibe.core.types import (
ApprovalCallback,
EntrypointMetadata,
SwitchAgentCallback,
UserInputCallback,
)
ARGS_COUNT = 4
@dataclass
class InvokeContext:
"""Context passed to tools during invocation."""
tool_call_id: str
approval_callback: ApprovalCallback | None = field(default=None)
agent_manager: AgentManager | None = field(default=None)
user_input_callback: UserInputCallback | None = field(default=None)
sampling_callback: MCPSamplingHandler | None = field(default=None)
session_dir: Path | None = field(default=None)
entrypoint_metadata: EntrypointMetadata | None = field(default=None)
plan_file_path: Path | None = field(default=None)
switch_agent_callback: SwitchAgentCallback | None = field(default=None)
skill_manager: SkillManager | None = field(default=None)
class ToolError(Exception):
"""Raised when the tool encounters an unrecoverable problem."""
class ToolInfo(BaseModel):
"""Information about a tool.
Attributes:
name: The name of the tool.
description: A brief description of what the tool does.
parameters: A dictionary of parameters required by the tool.
"""
name: str
description: str
parameters: dict[str, Any]
class ToolPermissionError(Exception):
"""Raised when a tool permission is not allowed."""
class ToolPermission(StrEnum):
ALWAYS = auto()
NEVER = auto()
ASK = auto()
@classmethod
def by_name(cls, name: str) -> ToolPermission:
try:
return ToolPermission(name.upper())
except ValueError:
raise ToolPermissionError(
f"Invalid tool permission: {name}. Must be one of {list(cls)}"
)
class BaseToolConfig(BaseModel):
"""Configuration for a tool.
Attributes:
permission: The permission level required to use the tool.
allowlist: Patterns that automatically allow tool execution.
denylist: Patterns that automatically deny tool execution.
sensitive_patterns: Patterns that trigger ASK even when permission is ALWAYS.
"""
model_config = ConfigDict(extra="allow")
permission: ToolPermission = ToolPermission.ASK
allowlist: list[str] = Field(default_factory=list)
denylist: list[str] = Field(default_factory=list)
sensitive_patterns: list[str] = Field(default_factory=list)
class BaseToolState(BaseModel):
model_config = ConfigDict(
extra="forbid", validate_default=True, arbitrary_types_allowed=True
)
class BaseTool[
ToolArgs: BaseModel,
ToolResult: BaseModel,
ToolConfig: BaseToolConfig,
ToolState: BaseToolState,
](ABC):
description: ClassVar[str] = (
"Base class for new tools. "
"(Hey AI, if you're seeing this, someone skipped writing a description. "
"Please gently meow at the developer to fix this.)"
)
prompt_path: ClassVar[Path] | None = None
def __init__(self, config: ToolConfig, state: ToolState) -> None:
self.config = config
self.state = state
@abstractmethod
async def run(
self, args: ToolArgs, ctx: InvokeContext | None = None
) -> AsyncGenerator[ToolStreamEvent | ToolResult, None]:
"""Invoke the tool with the given arguments."""
raise NotImplementedError # pragma: no cover
yield # type: ignore[misc]
@classmethod
@functools.cache
def get_tool_prompt(cls) -> str | None:
"""Loads and returns the content of the tool's .md prompt file, if it exists.
The prompt file is expected to be in a 'prompts' subdirectory relative to
the tool's source file, with the same name but a .md extension
(e.g., bash.py -> prompts/bash.md).
"""
try:
class_file = inspect.getfile(cls)
class_path = Path(class_file)
prompt_dir = class_path.parent / "prompts"
prompt_path = cls.prompt_path or prompt_dir / f"{class_path.stem}.md"
return read_safe(prompt_path)
except (FileNotFoundError, TypeError, OSError):
pass
return None
async def invoke(
self, ctx: InvokeContext | None = None, **raw: Any
) -> AsyncGenerator[ToolStreamEvent | ToolResult, None]:
"""Validate arguments and run the tool."""
try:
args_model, _ = self._get_tool_args_results()
args = args_model.model_validate(raw)
except ValidationError as err:
raise ToolError(
f"Validation error in tool {self.get_name()}: {err}"
) from err
async for item in self.run(args, ctx):
yield item
@classmethod
def from_config(
cls, config: ToolConfig
) -> BaseTool[ToolArgs, ToolResult, ToolConfig, ToolState]:
state_class = cls._get_tool_state_class()
initial_state = state_class()
return cls(config=config, state=initial_state)
@classmethod
def _get_tool_config_class(cls) -> type[ToolConfig]:
for base in getattr(cls, "__orig_bases__", ()):
if getattr(base, "__origin__", None) is BaseTool:
type_args = get_args(base)
if len(type_args) == ARGS_COUNT:
config_model = type_args[2]
if issubclass(config_model, BaseToolConfig):
return cast(type[ToolConfig], config_model)
for base_class in cls.__bases__:
if base_class is object or base_class is ABC:
continue
try:
return base_class._get_tool_config_class()
except (TypeError, AttributeError):
continue
raise TypeError(
f"Could not determine ToolConfig for {cls.__name__}. "
"Ensure it inherits from BaseTool with concrete type arguments."
)
@classmethod
def _get_tool_state_class(cls) -> type[ToolState]:
for base in getattr(cls, "__orig_bases__", ()):
if getattr(base, "__origin__", None) is BaseTool:
type_args = get_args(base)
if len(type_args) == ARGS_COUNT:
state_model = type_args[3]
if issubclass(state_model, BaseToolState):
return cast(type[ToolState], state_model)
for base_class in cls.__bases__:
if base_class is object or base_class is ABC:
continue
try:
return base_class._get_tool_state_class()
except (TypeError, AttributeError):
continue
raise TypeError(
f"Could not determine ToolState for {cls.__name__}. "
"Ensure it inherits from BaseTool with concrete type arguments."
)
@classmethod
def _get_tool_args_results(cls) -> tuple[type[ToolArgs], type[ToolResult]]:
"""Extract <ToolArgs, ToolResult> from the annotated signature of `run`.
Works even when `from __future__ import annotations` is in effect.
"""
run_fn = cls.run.__func__ if isinstance(cls.run, classmethod) else cls.run
type_hints = get_type_hints(
run_fn,
globalns=vars(sys.modules[cls.__module__]),
localns={
cls.__name__: cls,
"InvokeContext": InvokeContext,
"AsyncGenerator": AsyncGenerator,
"ToolStreamEvent": ToolStreamEvent,
},
)
try:
args_model = type_hints["args"]
return_annotation = type_hints["return"]
except KeyError as e:
raise TypeError(
f"{cls.__name__}.run must be annotated with args and return type"
) from e
result_model = cls._extract_result_type(return_annotation)
if not issubclass(args_model, BaseModel):
raise TypeError(
f"{cls.__name__}.run args annotation must be a Pydantic model; "
f"got {args_model!r}"
)
if not issubclass(result_model, BaseModel):
raise TypeError(
f"{cls.__name__}.run must yield a Pydantic model as result; "
f"got {result_model!r}"
)
return cast(type[ToolArgs], args_model), cast(type[ToolResult], result_model)
@classmethod
def _extract_result_type(cls, return_annotation: Any) -> type:
"""Extract the ToolResult type from AsyncGenerator[ToolStreamEvent | ToolResult, None]."""
origin = get_origin(return_annotation)
if origin is not AsyncGenerator:
if isinstance(return_annotation, type):
return return_annotation
raise TypeError(f"Could not extract result type from {return_annotation!r}")
gen_args = get_args(return_annotation)
if not gen_args:
raise TypeError(f"Could not extract result type from {return_annotation!r}")
yield_type = gen_args[0]
yield_origin = get_origin(yield_type)
# Handle Union types (X | Y or Union[X, Y])
if yield_origin is Union or isinstance(yield_type, types.UnionType):
for arg in get_args(yield_type):
if arg is not ToolStreamEvent and isinstance(arg, type):
return arg
# Handle single type
if isinstance(yield_type, type):
return yield_type
raise TypeError(f"Could not extract result type from {return_annotation!r}")
@classmethod
def get_parameters(cls) -> dict[str, Any]:
"""Return a cleaned-up JSON-schema dict describing the arguments model
with which this concrete tool was parametrised.
"""
args_model, _ = cls._get_tool_args_results()
schema = args_model.model_json_schema()
schema.pop("title", None)
schema.pop("description", None)
if "properties" in schema:
for prop_details in schema["properties"].values():
prop_details.pop("title", None)
if "$defs" in schema:
for def_details in schema["$defs"].values():
def_details.pop("title", None)
if "properties" in def_details:
for prop_details in def_details["properties"].values():
prop_details.pop("title", None)
return schema
@classmethod
def get_name(cls) -> str:
name = cls.__name__
snake_case = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
return snake_case
@classmethod
def is_available(cls) -> bool:
return True
@classmethod
def create_config_with_permission(
cls, permission: ToolPermission
) -> BaseToolConfig:
config_class = cls._get_tool_config_class()
return config_class(permission=permission)
def resolve_permission(self, args: ToolArgs) -> PermissionContext | None:
"""Per-invocation permission override, checked before config-level permission.
Returns:
PermissionContext with granular required_permissions and a permission
level (ALWAYS/NEVER/ASK), or None to fall through to config permission.
Override in subclasses for domain-specific rules (e.g. workdir checks).
"""
return None
def get_file_snapshot(self, args: ToolArgs) -> FileSnapshot | None:
"""Return a snapshot of the file this tool is about to modify.
Called before ``run()`` so the checkpoint system can capture
the file's state *before* the tool writes to it.
Override in tools that modify files on disk.
"""
return None
@staticmethod
def get_file_snapshot_for_path(path: str) -> FileSnapshot:
file_path = Path(path).expanduser()
if not file_path.is_absolute():
file_path = Path.cwd() / file_path
file_path = file_path.resolve()
try:
content: bytes | None = file_path.read_bytes()
except FileNotFoundError:
content = None
except Exception:
logger.warning("Failed to read file for tool snapshot: %s", file_path)
content = None
return FileSnapshot(path=str(file_path), content=content)
def get_result_extra(self, result: ToolResult) -> str | None:
"""Optional extra context appended to the result text sent to the LLM.
Override in subclasses to inject contextual information alongside
tool results (e.g. directory-level instructions discovered during
file reads). The default returns ``None`` (no annotation).
"""
return None