Mirror of https://github.com/mistralai/mistral-vibe (synced 2026-04-25 17:14:55 +02:00).
Co-authored-by: Quentin Torroba <quentin.torroba@mistral.ai> Co-authored-by: Clément Siriex <clement.sirieix@mistral.ai> Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai> Co-authored-by: Michel Thomazo <michel.thomazo@mistral.ai> Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
622 lines · 20 KiB · Python
from __future__ import annotations
|
|
|
|
from collections.abc import MutableMapping
|
|
from enum import StrEnum, auto
|
|
import os
|
|
from pathlib import Path
|
|
import re
|
|
import shlex
|
|
import tomllib
|
|
from typing import Annotated, Any, Literal
|
|
|
|
from dotenv import dotenv_values
|
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
from pydantic.fields import FieldInfo
|
|
from pydantic_core import to_jsonable_python
|
|
from pydantic_settings import (
|
|
BaseSettings,
|
|
PydanticBaseSettingsSource,
|
|
SettingsConfigDict,
|
|
)
|
|
import tomli_w
|
|
|
|
from vibe.core.paths.config_paths import CONFIG_DIR, CONFIG_FILE, PROMPTS_DIR
|
|
from vibe.core.paths.global_paths import (
|
|
GLOBAL_ENV_FILE,
|
|
GLOBAL_PROMPTS_DIR,
|
|
SESSION_LOG_DIR,
|
|
)
|
|
from vibe.core.prompts import SystemPrompt
|
|
from vibe.core.tools.base import BaseToolConfig
|
|
|
|
|
|
def load_dotenv_values(
    env_path: Path = GLOBAL_ENV_FILE.path,
    environ: MutableMapping[str, str] = os.environ,
) -> None:
    """Copy the non-empty entries of a dotenv file into ``environ``.

    Args:
        env_path: Location of the dotenv file. Note that the default is captured
            once, when this module is imported.
        environ: Mapping that receives the values (``os.environ`` by default).

    Entries whose parsed value is falsy (empty string or ``None``) are skipped,
    so they never overwrite an existing variable. Nothing happens when
    ``env_path`` is neither a regular file nor a FIFO.
    """
    # FIFOs are accepted on purpose to support environment managers that expose
    # the env file through a named pipe
    # (e.g. https://developer.1password.com/docs/environments/local-env-file/).
    readable = env_path.is_file() or env_path.is_fifo()
    if not readable:
        return

    for name, raw in dotenv_values(env_path).items():
        if raw:
            environ[name] = raw
|
|
|
|
|
|
class MissingAPIKeyError(RuntimeError):
    """Raised when the environment variable holding a provider's API key is unset.

    Attributes:
        env_key: Name of the environment variable that was expected.
        provider_name: Name of the provider that needs the key.
    """

    def __init__(self, env_key: str, provider_name: str) -> None:
        self.env_key = env_key
        self.provider_name = provider_name
        message = f"Missing {env_key} environment variable for {provider_name} provider"
        super().__init__(message)
|
|
|
|
|
|
class MissingPromptFileError(RuntimeError):
    """Raised when ``system_prompt_id`` matches neither a built-in prompt nor a file.

    Attributes:
        system_prompt_id: The identifier that could not be resolved.
        prompt_dir: The first prompt directory that was searched.
        global_prompt_dir: The second (global) prompt directory that was searched.
    """

    def __init__(
        self, system_prompt_id: str, prompt_dir: str, global_prompt_dir: str
    ) -> None:
        # Only mention the global directory when it differs from the first one,
        # otherwise the message would list the same path twice.
        extra_global_prompt_dir = (
            f" or {global_prompt_dir}" if global_prompt_dir != prompt_dir else ""
        )
        builtin_prompts = ", ".join(p.name.lower() for p in SystemPrompt)

        super().__init__(
            f"Invalid system_prompt_id value: '{system_prompt_id}'. "
            f"Must be one of the available prompts ({builtin_prompts}), "
            f"or correspond to a .md file in {prompt_dir}{extra_global_prompt_dir}"
        )
        self.system_prompt_id = system_prompt_id
        self.prompt_dir = prompt_dir
        # Stored alongside prompt_dir so handlers can report every searched location.
        self.global_prompt_dir = global_prompt_dir
|
|
|
|
|
|
class WrongBackendError(RuntimeError):
    """Raised when a provider's ``backend`` does not match the kind of API it targets.

    Attributes:
        backend: The backend configured on the provider.
        is_mistral_api: Whether the provider's ``api_base`` was recognised as a
            Mistral API.
    """

    def __init__(self, backend: Backend, is_mistral_api: bool) -> None:
        self.backend = backend
        self.is_mistral_api = is_mistral_api
        qualifier = "" if is_mistral_api else "non-"
        super().__init__(
            f"Wrong backend '{backend}' for {qualifier}mistral API. "
            f"Use '{Backend.MISTRAL}' for mistral API and '{Backend.GENERIC}' for others."
        )
|
|
|
|
|
|
class TomlFileSettingsSource(PydanticBaseSettingsSource):
    """pydantic-settings source that supplies values from the Vibe TOML config file.

    The file is read once, when the source is constructed. A missing file is
    treated as an empty configuration; any other read or parse failure is
    reported as a ``RuntimeError`` naming the file.
    """

    def __init__(self, settings_cls: type[BaseSettings]) -> None:
        super().__init__(settings_cls)
        self.toml_data = self._load_toml()

    def _load_toml(self) -> dict[str, Any]:
        """Parse ``CONFIG_FILE`` and return its top-level table.

        Returns:
            The parsed TOML document, or ``{}`` if the file does not exist.

        Raises:
            RuntimeError: If the file is not valid UTF-8 TOML or cannot be read.
        """
        file = CONFIG_FILE.path
        try:
            with file.open("rb") as f:
                return tomllib.load(f)
        except FileNotFoundError:
            # Must precede the OSError clause: FileNotFoundError is a subclass.
            return {}
        except (tomllib.TOMLDecodeError, UnicodeDecodeError) as e:
            # tomllib decodes the raw bytes itself, so a non-UTF-8 file surfaces
            # as UnicodeDecodeError rather than TOMLDecodeError; report both alike.
            raise RuntimeError(f"Invalid TOML in {file}: {e}") from e
        except OSError as e:
            raise RuntimeError(f"Cannot read {file}: {e}") from e

    def get_field_value(
        self, field: FieldInfo, field_name: str
    ) -> tuple[Any, str, bool]:
        """Return ``(value, key, is_complex)`` for one field; value is ``None`` if absent."""
        return self.toml_data.get(field_name), field_name, False

    def __call__(self) -> dict[str, Any]:
        """Return the whole parsed document for pydantic-settings to merge."""
        return self.toml_data
|
|
|
|
|
|
class ProjectContextConfig(BaseSettings):
    """Tunable limits for collecting project context.

    NOTE(review): these values are consumed outside this module; the meaning of
    each limit is inferred from its name — confirm against the consumer.
    """

    max_chars: int = 40000
    default_commit_count: int = 5
    max_doc_bytes: int = 32768  # 32 KiB
    truncation_buffer: int = 1000
    max_depth: int = 3
    max_files: int = 1_000
    max_dirs_per_level: int = 20
    timeout_seconds: float = 2.0
|
|
|
|
|
|
class SessionLoggingConfig(BaseSettings):
    """Whether session logs are written, where, and under which file prefix.

    ``save_dir`` goes through two validators: an empty value is replaced by the
    global session-log directory, and the result is stored as an absolute,
    ``~``-expanded path string.
    """

    save_dir: str = ""
    session_prefix: str = "session"
    enabled: bool = True

    @field_validator("save_dir", mode="before")
    @classmethod
    def set_default_save_dir(cls, v: str) -> str:
        """Substitute the global session-log directory for a falsy value."""
        return v if v else str(SESSION_LOG_DIR.path)

    @field_validator("save_dir", mode="after")
    @classmethod
    def expand_save_dir(cls, v: str) -> str:
        """Expand ``~`` and resolve the directory to an absolute path."""
        resolved = Path(v).expanduser().resolve()
        return str(resolved)
|
|
|
|
|
|
class Backend(StrEnum):
|
|
MISTRAL = auto()
|
|
GENERIC = auto()
|
|
|
|
|
|
class ProviderConfig(BaseModel):
    """Connection settings for one model provider, referenced by ``ModelConfig.provider``."""

    # Identifier that ModelConfig.provider must match.
    name: str
    # Base URL of the provider's API.
    api_base: str
    # Environment variable holding the API key; "" means no key is checked.
    api_key_env_var: str = ""
    # NOTE(review): presumably selects the request/response wire format — confirm.
    api_style: str = "openai"
    # Must agree with api_base: MISTRAL for Mistral-hosted URLs, GENERIC otherwise
    # (VibeConfig enforces this for the active provider).
    backend: Backend = Backend.GENERIC
    # NOTE(review): presumably the response field carrying reasoning text — confirm.
    reasoning_field_name: str = "reasoning_content"
    # NOTE(review): project_id/region are not used in this module; they look
    # specific to some hosted backends — confirm with their consumer.
    project_id: str = ""
    region: str = ""
|
|
|
|
|
|
class _MCPBase(BaseModel):
    """Fields shared by every MCP server entry, whatever its transport."""

    name: str = Field(description="Short alias used to prefix tool names")
    prompt: str | None = Field(
        default=None, description="Optional usage hint appended to tool descriptions"
    )
    startup_timeout_sec: float = Field(
        default=10.0,
        gt=0,
        description="Timeout in seconds for the server to start and initialize.",
    )
    tool_timeout_sec: float = Field(
        default=60.0, gt=0, description="Timeout in seconds for tool execution."
    )

    @field_validator("name", mode="after")
    @classmethod
    def normalize_name(cls, v: str) -> str:
        """Make ``name`` safe to use as a tool-name prefix.

        Every character outside ``[a-zA-Z0-9_-]`` becomes ``_``, leading and
        trailing ``_``/``-`` are stripped, and the result is capped at 256 chars.
        """
        sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", v).strip("_-")
        return sanitized[:256]
|
|
|
|
|
|
class _MCPHttpFields(BaseModel):
    """Connection fields shared by the HTTP-based MCP transports."""

    url: str = Field(description="Base URL of the MCP HTTP server")
    headers: dict[str, str] = Field(
        default_factory=dict,
        description=(
            "Additional HTTP headers when using 'http' transport (e.g., Authorization or X-API-Key)."
        ),
    )
    api_key_env: str = Field(
        default="",
        description=(
            "Environment variable name containing an API token to send for HTTP transport."
        ),
    )
    api_key_header: str = Field(
        default="Authorization",
        description=(
            "HTTP header name to carry the token when 'api_key_env' is set (e.g., 'Authorization' or 'X-API-Key')."
        ),
    )
    api_key_format: str = Field(
        default="Bearer {token}",
        description=(
            "Format string for the header value when 'api_key_env' is set. Use '{token}' placeholder."
        ),
    )

    def http_headers(self) -> dict[str, str]:
        """Build the HTTP headers to send to the server.

        Starts from a copy of ``headers``. When ``api_key_env`` names a set,
        non-empty environment variable, its value is rendered through
        ``api_key_format`` and added under ``api_key_header`` (``Authorization``
        if blank) — unless a header of that name is already present, compared
        case-insensitively, in which case the explicit header wins.
        """
        headers = dict(self.headers or {})
        env_name = (self.api_key_env or "").strip()
        if not env_name:
            return headers
        token = os.getenv(env_name)
        if not token:
            return headers

        header_name = (self.api_key_header or "").strip() or "Authorization"
        wanted = header_name.lower()
        if any(existing.lower() == wanted for existing in headers):
            return headers

        template = self.api_key_format or "{token}"
        try:
            headers[header_name] = template.format(token=token)
        except Exception:
            # Best effort: a malformed template must not break the connection,
            # so fall back to sending the bare token.
            headers[header_name] = token
        return headers
|
|
|
|
|
|
class MCPHttp(_MCPBase, _MCPHttpFields):
    """MCP server entry selected by ``transport = "http"``."""

    transport: Literal["http"]
|
|
|
|
|
|
class MCPStreamableHttp(_MCPBase, _MCPHttpFields):
    """MCP server entry selected by ``transport = "streamable-http"``."""

    transport: Literal["streamable-http"]
|
|
|
|
|
|
class MCPStdio(_MCPBase):
    """MCP server entry selected by ``transport = "stdio"``.

    ``argv()`` yields the command line used to start the server process.
    """

    transport: Literal["stdio"]
    command: str | list[str]
    args: list[str] = Field(default_factory=list)
    env: dict[str, str] = Field(
        default_factory=dict,
        description="Environment variables to set for the MCP server process.",
    )

    def argv(self) -> list[str]:
        """Return ``command`` split into tokens, followed by ``args``.

        A string ``command`` is tokenised with ``shlex.split`` (POSIX rules);
        a list ``command`` is copied as-is.
        """
        if isinstance(self.command, str):
            tokens = shlex.split(self.command)
        else:
            tokens = list(self.command or [])
        if not self.args:
            return tokens
        return tokens + list(self.args)
|
|
|
|
|
|
# Discriminated union of the MCP server entry types: the value of the
# ``transport`` field ("http", "streamable-http" or "stdio") picks the model.
MCPServer = Annotated[
    MCPHttp | MCPStreamableHttp | MCPStdio,
    Field(discriminator="transport"),
]
|
|
|
|
|
|
class ModelConfig(BaseModel):
    """A model entry, selected through its ``alias``.

    ``alias`` may be omitted (or given as ``None``) in input; it then defaults
    to ``name``.
    """

    name: str
    provider: str  # name of a ProviderConfig (see VibeConfig.get_provider_for_model)
    alias: str
    temperature: float = 0.2
    input_price: float = 0.0  # Price per million input tokens
    output_price: float = 0.0  # Price per million output tokens
    thinking: Literal["off", "low", "medium", "high"] = "off"

    @model_validator(mode="before")
    @classmethod
    def _default_alias_to_name(cls, data: Any) -> Any:
        """Fill in ``alias`` from ``name`` when it is absent or ``None``.

        A new dict is returned instead of writing into ``data``, so the caller's
        mapping (e.g. a table parsed from the TOML config) is not mutated as a
        side effect of validation.
        """
        if isinstance(data, dict) and data.get("alias") is None:
            return {**data, "alias": data.get("name")}
        return data
|
|
|
|
|
|
# Environment variable holding the API key of the built-in "mistral" provider.
DEFAULT_MISTRAL_API_ENV_KEY = "MISTRAL_API_KEY"
|
|
|
|
|
|
# Providers available out of the box: Mistral's hosted API and a local
# llama.cpp server listening on 127.0.0.1:8080.
DEFAULT_PROVIDERS = [
    ProviderConfig(
        name="mistral",
        api_base="https://api.mistral.ai/v1",
        api_key_env_var=DEFAULT_MISTRAL_API_ENV_KEY,
        backend=Backend.MISTRAL,
    ),
    ProviderConfig(
        name="llamacpp",
        api_base="http://127.0.0.1:8080/v1",
        # Empty means no key is required. If you start llama-server with
        # --api-key, set this to the environment variable holding that key.
        api_key_env_var="",
    ),
]
|
|
|
|
# Models shipped by default. ``alias`` is the value ``VibeConfig.active_model``
# refers to (its default is "devstral-2"); prices are per million tokens.
DEFAULT_MODELS = [
    # Hosted models served by the "mistral" provider.
    ModelConfig(
        name="mistral-vibe-cli-latest",
        provider="mistral",
        alias="devstral-2",
        input_price=0.4,
        output_price=2.0,
    ),
    ModelConfig(
        name="devstral-small-latest",
        provider="mistral",
        alias="devstral-small",
        input_price=0.1,
        output_price=0.3,
    ),
    # Local model served by the "llamacpp" provider, priced at zero.
    ModelConfig(
        name="devstral",
        provider="llamacpp",
        alias="local",
        input_price=0.0,
        output_price=0.0,
    ),
]
|
|
|
|
|
|
class VibeConfig(BaseSettings):
    """Top-level Vibe configuration.

    Sources, highest priority first: constructor arguments, ``VIBE_``-prefixed
    environment variables (case-insensitive), the TOML config file, then file
    secrets (see ``settings_customise_sources``). Unknown keys are ignored.
    After field validation the config checks that the active provider's API
    key variable is set, that its backend matches its ``api_base``, that model
    aliases are unique, and that ``system_prompt_id`` resolves to a prompt.
    """

    active_model: str = "devstral-2"
    vim_keybindings: bool = False
    disable_welcome_banner_animation: bool = False
    autocopy_to_clipboard: bool = True
    displayed_workdir: str = ""
    auto_compact_threshold: int = 200_000
    context_warnings: bool = False
    auto_approve: bool = False
    enable_telemetry: bool = True
    system_prompt_id: str = "cli"
    include_commit_signature: bool = True
    include_model_info: bool = True
    include_project_context: bool = True
    include_prompt_detail: bool = True
    enable_update_checks: bool = True
    enable_auto_update: bool = True
    api_timeout: float = 720.0

    # TODO(vibe-nuage): remove exclude=True once the feature is publicly available
    nuage_enabled: bool = Field(default=False, exclude=True)
    nuage_base_url: str = Field(default="https://api.globalaegis.net", exclude=True)
    nuage_workflow_id: str = Field(default="__shared-nuage-workflow", exclude=True)
    # TODO(vibe-nuage): change default value to MISTRAL_API_KEY once prod has shared vibe-nuage workers
    nuage_api_key_env_var: str = Field(default="STAGING_MISTRAL_API_KEY", exclude=True)

    # Deep copies: the defaults are module-level model instances, and handing
    # the same objects to every config would let an in-place edit on one config
    # leak into DEFAULT_PROVIDERS / DEFAULT_MODELS and every later config.
    providers: list[ProviderConfig] = Field(
        default_factory=lambda: [p.model_copy(deep=True) for p in DEFAULT_PROVIDERS]
    )
    models: list[ModelConfig] = Field(
        default_factory=lambda: [m.model_copy(deep=True) for m in DEFAULT_MODELS]
    )

    project_context: ProjectContextConfig = Field(default_factory=ProjectContextConfig)
    session_logging: SessionLoggingConfig = Field(default_factory=SessionLoggingConfig)
    tools: dict[str, BaseToolConfig] = Field(default_factory=dict)
    tool_paths: list[Path] = Field(
        default_factory=list,
        description=(
            "Additional directories or files to explore for custom tools. "
            "Paths may be absolute or relative to the current working directory. "
            "Directories are shallow-searched for tool definition files, "
            "while files are loaded directly if valid."
        ),
    )

    mcp_servers: list[MCPServer] = Field(
        default_factory=list, description="Preferred MCP server configuration entries."
    )

    enabled_tools: list[str] = Field(
        default_factory=list,
        description=(
            "An explicit list of tool names/patterns to enable. If set, only these"
            " tools will be active. Supports glob patterns (e.g., 'serena_*') and"
            " regex with 're:' prefix (e.g., 're:^serena_.*')."
        ),
    )
    disabled_tools: list[str] = Field(
        default_factory=list,
        description=(
            "A list of tool names/patterns to disable. Ignored if 'enabled_tools'"
            " is set. Supports glob patterns and regex with 're:' prefix."
        ),
    )
    agent_paths: list[Path] = Field(
        default_factory=list,
        description=(
            "Additional directories to search for custom agent profiles. "
            "Each path may be absolute or relative to the current working directory."
        ),
    )
    enabled_agents: list[str] = Field(
        default_factory=list,
        description=(
            "An explicit list of agent names/patterns to enable. If set, only these"
            " agents will be available. Supports glob patterns (e.g., 'custom-*')"
            " and regex with 're:' prefix."
        ),
    )
    disabled_agents: list[str] = Field(
        default_factory=list,
        description=(
            "A list of agent names/patterns to disable. Ignored if 'enabled_agents'"
            " is set. Supports glob patterns and regex with 're:' prefix."
        ),
    )
    skill_paths: list[Path] = Field(
        default_factory=list,
        description=(
            "Additional directories to search for skills. "
            "Each path may be absolute or relative to the current working directory."
        ),
    )
    enabled_skills: list[str] = Field(
        default_factory=list,
        description=(
            "An explicit list of skill names/patterns to enable. If set, only these"
            " skills will be active. Supports glob patterns (e.g., 'search-*') and"
            " regex with 're:' prefix."
        ),
    )
    disabled_skills: list[str] = Field(
        default_factory=list,
        description=(
            "A list of skill names/patterns to disable. Ignored if 'enabled_skills'"
            " is set. Supports glob patterns and regex with 're:' prefix."
        ),
    )

    model_config = SettingsConfigDict(
        env_prefix="VIBE_", case_sensitive=False, extra="ignore"
    )

    @property
    def nuage_api_key(self) -> str:
        """Value of the ``nuage_api_key_env_var`` environment variable, or ``""``."""
        return os.getenv(self.nuage_api_key_env_var, "")

    @property
    def system_prompt(self) -> str:
        """Return the text of the configured system prompt.

        ``system_prompt_id.upper()`` is first looked up as a ``SystemPrompt``
        member name; otherwise ``<system_prompt_id>.md`` is searched in
        ``PROMPTS_DIR`` and then ``GLOBAL_PROMPTS_DIR``. An id that already
        ends in ``.md`` is used as the file name unchanged.

        Raises:
            MissingPromptFileError: If no built-in prompt or file matches.
        """
        prompt_id = self.system_prompt_id
        try:
            return SystemPrompt[prompt_id.upper()].read()
        except KeyError:
            pass

        if prompt_id:
            # Append ".md" rather than using Path.with_suffix(), which replaces
            # whatever follows the last dot ("v1.2" would become "v1.md") and
            # would turn an empty id into "<prompt_dir>.md".
            file_name = prompt_id if prompt_id.endswith(".md") else f"{prompt_id}.md"
            for prompt_dir in (PROMPTS_DIR.path, GLOBAL_PROMPTS_DIR.path):
                candidate = prompt_dir / file_name
                if candidate.is_file():
                    # Explicit UTF-8 so prompts read identically on every
                    # platform instead of depending on the locale encoding.
                    return candidate.read_text(encoding="utf-8")

        raise MissingPromptFileError(
            prompt_id, str(PROMPTS_DIR.path), str(GLOBAL_PROMPTS_DIR.path)
        )

    def get_active_model(self) -> ModelConfig:
        """Return the model whose ``alias`` equals ``active_model``.

        Raises:
            ValueError: If no configured model has that alias.
        """
        for model in self.models:
            if model.alias == self.active_model:
                return model
        raise ValueError(
            f"Active model '{self.active_model}' not found in configuration."
        )

    def get_provider_for_model(self, model: ModelConfig) -> ProviderConfig:
        """Return the provider whose ``name`` equals ``model.provider``.

        Raises:
            ValueError: If no configured provider has that name.
        """
        for provider in self.providers:
            if provider.name == model.provider:
                return provider
        raise ValueError(
            f"Provider '{model.provider}' for model '{model.name}' not found in configuration."
        )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Define the priority of settings sources (earlier entries win).

        Note: dotenv_settings is intentionally excluded. API keys and other
        non-config environment variables are stored in .env but loaded manually
        into os.environ for use by providers. Only VIBE_* prefixed environment
        variables (via env_settings) and TOML config are used for Pydantic settings.
        """
        return (
            init_settings,
            env_settings,
            TomlFileSettingsSource(settings_cls),
            file_secret_settings,
        )

    @model_validator(mode="after")
    def _check_api_key(self) -> VibeConfig:
        """Fail if the active provider's API-key variable is unset or empty.

        An unknown active model or provider (ValueError from the lookups) is
        deliberately not reported here; it is left to later callers of
        ``get_active_model`` / ``get_provider_for_model``.

        Raises:
            MissingAPIKeyError: If the provider names a key variable that is empty.
        """
        try:
            provider = self.get_provider_for_model(self.get_active_model())
        except ValueError:
            return self

        api_key_env = provider.api_key_env_var
        if api_key_env and not os.getenv(api_key_env):
            raise MissingAPIKeyError(api_key_env, provider.name)
        return self

    @model_validator(mode="after")
    def _check_api_backend_compatibility(self) -> VibeConfig:
        """Require Backend.MISTRAL for Mistral API URLs and Backend.GENERIC otherwise.

        Lookup failures for the active model/provider are ignored, as in
        ``_check_api_key``.

        Raises:
            WrongBackendError: If the active provider's backend does not match.
        """
        try:
            provider = self.get_provider_for_model(self.get_active_model())
        except ValueError:
            return self

        mistral_api_bases = ("https://codestral.mistral.ai", "https://api.mistral.ai")
        is_mistral_api = provider.api_base.startswith(mistral_api_bases)
        expected = Backend.MISTRAL if is_mistral_api else Backend.GENERIC
        if provider.backend != expected:
            raise WrongBackendError(provider.backend, is_mistral_api)
        return self

    @staticmethod
    def _expand_path_list(v: Any) -> list[Path]:
        """Normalise a raw path-list value to absolute, ``~``-expanded paths.

        A falsy value yields ``[]``. A single ``str``/``Path`` is treated as a
        one-element list instead of being iterated character by character.
        Relative paths are resolved against the current working directory.
        """
        if not v:
            return []
        if isinstance(v, (str, Path)):
            v = [v]
        return [Path(p).expanduser().resolve() for p in v]

    @field_validator("tool_paths", mode="before")
    @classmethod
    def _expand_tool_paths(cls, v: Any) -> list[Path]:
        """Expand and resolve ``tool_paths``."""
        return cls._expand_path_list(v)

    @field_validator("skill_paths", mode="before")
    @classmethod
    def _expand_skill_paths(cls, v: Any) -> list[Path]:
        """Expand and resolve ``skill_paths``."""
        return cls._expand_path_list(v)

    @field_validator("agent_paths", mode="before")
    @classmethod
    def _expand_agent_paths(cls, v: Any) -> list[Path]:
        """Expand and resolve ``agent_paths``, consistent with tool/skill paths."""
        return cls._expand_path_list(v)

    @field_validator("tools", mode="before")
    @classmethod
    def _normalize_tool_configs(cls, v: Any) -> dict[str, BaseToolConfig]:
        """Coerce each ``tools`` entry to a ``BaseToolConfig``.

        A non-dict ``tools`` value becomes ``{}``; dict entries are validated into
        ``BaseToolConfig``; any other entry value becomes a default ``BaseToolConfig``.
        """
        if not isinstance(v, dict):
            return {}

        normalized: dict[str, BaseToolConfig] = {}
        for tool_name, tool_config in v.items():
            if isinstance(tool_config, BaseToolConfig):
                normalized[tool_name] = tool_config
            elif isinstance(tool_config, dict):
                normalized[tool_name] = BaseToolConfig.model_validate(tool_config)
            else:
                normalized[tool_name] = BaseToolConfig()
        return normalized

    @model_validator(mode="after")
    def _validate_model_uniqueness(self) -> VibeConfig:
        """Reject configurations in which two models share an alias."""
        seen_aliases: set[str] = set()
        for model in self.models:
            if model.alias in seen_aliases:
                raise ValueError(
                    f"Duplicate model alias found: '{model.alias}'. Aliases must be unique."
                )
            seen_aliases.add(model.alias)
        return self

    @model_validator(mode="after")
    def _check_system_prompt(self) -> VibeConfig:
        """Resolve the system prompt once so a bad ``system_prompt_id`` fails early."""
        _ = self.system_prompt
        return self

    @classmethod
    def save_updates(cls, updates: dict[str, Any]) -> None:
        """Merge ``updates`` into the TOML config file and write it back.

        Nested tables are merged recursively. When both sides hold a list,
        ``providers`` and ``models`` are replaced wholesale while any other list
        becomes the union of new and existing items (new items first, order
        kept, duplicates dropped). Everything else is overwritten. ``None``
        values are dropped and unknown types stringified on serialization.
        """
        CONFIG_DIR.path.mkdir(parents=True, exist_ok=True)
        current_config = TomlFileSettingsSource(cls).toml_data

        def merge_lists(new: list[Any], old: list[Any]) -> list[Any]:
            # Order-preserving union by equality rather than via set(): the
            # output order stays deterministic (str hashing is randomised per
            # process) and unhashable items such as MCP server tables work.
            merged: list[Any] = []
            for item in (*new, *old):
                if item not in merged:
                    merged.append(item)
            return merged

        def deep_merge(target: dict, source: dict) -> None:
            for key, value in source.items():
                existing = target.get(key)
                if isinstance(existing, dict) and isinstance(value, dict):
                    deep_merge(existing, value)
                elif isinstance(existing, list) and isinstance(value, list):
                    if key in {"providers", "models"}:
                        target[key] = value
                    else:
                        target[key] = merge_lists(value, existing)
                else:
                    target[key] = value

        deep_merge(current_config, updates)
        cls.dump_config(
            to_jsonable_python(current_config, exclude_none=True, fallback=str)
        )

    @classmethod
    def dump_config(cls, config: dict[str, Any]) -> None:
        """Overwrite the TOML config file with ``config``.

        The document is serialised completely before the file is opened:
        ``tomli_w.dump`` writes chunk by chunk, so a value it cannot encode
        would otherwise leave a truncated config file behind.
        """
        payload = tomli_w.dumps(config).encode()
        with CONFIG_FILE.path.open("wb") as f:
            f.write(payload)

    @classmethod
    def _migrate(cls) -> None:
        """Hook for upgrading older config files before loading; currently a no-op."""
        pass

    @classmethod
    def load(cls, **overrides: Any) -> VibeConfig:
        """Run ``_migrate`` and build a config with ``overrides`` at top priority."""
        cls._migrate()
        return cls(**overrides)

    @classmethod
    def create_default(cls) -> dict[str, Any]:
        """Return the configuration as a JSON-compatible dict for writing a config file.

        Normally this is the fully resolved config (defaults overlaid by env and
        TOML). If the API key is missing it falls back to ``model_construct()``,
        i.e. unvalidated field defaults. Default tool settings discovered by
        ``ToolManager`` replace ``"tools"`` when there are any.
        """
        try:
            config = cls()
        except MissingAPIKeyError:
            config = cls.model_construct()

        config_dict = config.model_dump(mode="json", exclude_none=True)

        # NOTE(review): imported lazily, presumably to avoid an import cycle.
        from vibe.core.tools.manager import ToolManager

        tool_defaults = ToolManager.discover_tool_defaults()
        if tool_defaults:
            config_dict["tools"] = tool_defaults

        return config_dict
|