Files
mistral-vibe/tests/tools/test_mcp.py
Clément Sirieix 04305bd77c v2.8.0 (#616)
Co-authored-by: Kim-Adeline Miguel <51720070+kimadeline@users.noreply.github.com>
Co-authored-by: Mathias Gesbert <mathias.gesbert@mistral.ai>
Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com>
Co-authored-by: Pierre Rossinès <pierre.rossines@mistral.ai>
Co-authored-by: Mistral Vibe <vibe@mistral.ai>
2026-04-21 15:19:59 +02:00

713 lines
24 KiB
Python

from __future__ import annotations
import logging
import os
import threading
import time
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, patch
from pydantic import ValidationError
import pytest
from vibe.core.config import MCPHttp, MCPStdio, MCPStreamableHttp
from vibe.core.tools.mcp import (
MCPRegistry,
MCPToolResult,
RemoteTool,
_mcp_stderr_capture,
_parse_call_result,
_stderr_logger_thread,
call_tool_stdio,
create_mcp_http_proxy_tool_class,
create_mcp_stdio_proxy_tool_class,
list_tools_stdio,
)
class TestRemoteTool:
    """Validation and input-schema normalization performed by ``RemoteTool``."""

    _MISSING_NAME = "MCP tool missing valid 'name'"

    def test_creates_remote_tool_with_valid_data(self):
        raw = {
            "name": "test_tool",
            "description": "A test tool",
            "inputSchema": {
                "type": "object",
                "properties": {"arg": {"type": "string"}},
            },
        }
        expected_schema = {
            "type": "object",
            "properties": {"arg": {"type": "string"}},
        }

        tool = RemoteTool.model_validate(raw)

        assert tool.name == "test_tool"
        assert tool.description == "A test tool"
        assert tool.input_schema == expected_schema

    def test_uses_default_schema_when_none_provided(self):
        # Only the name is supplied, so the schema falls back to an empty object.
        tool = RemoteTool(name="test_tool")
        assert tool.input_schema == {"type": "object", "properties": {}}

    def test_rejects_empty_name(self):
        with pytest.raises(ValueError, match=self._MISSING_NAME):
            RemoteTool(name="")

    def test_rejects_whitespace_only_name(self):
        with pytest.raises(ValueError, match=self._MISSING_NAME):
            RemoteTool(name=" ")

    def test_normalizes_schema_from_object_with_model_dump(self):
        # Objects exposing ``model_dump`` are converted to the dict it returns.
        schema_obj = MagicMock(**{"model_dump.return_value": {"type": "string"}})
        tool = RemoteTool.model_validate({"name": "test", "inputSchema": schema_obj})
        assert tool.input_schema == {"type": "string"}

    def test_rejects_invalid_input_schema(self):
        bogus = {"name": "test", "inputSchema": 12345}
        with pytest.raises(ValueError, match="inputSchema must be a dict"):
            RemoteTool.model_validate(bogus)
class TestMCPToolResult:
    """Field defaults of ``MCPToolResult`` for text and structured payloads."""

    def test_creates_result_with_text(self):
        res = MCPToolResult(server="test_server", tool="test_tool", text="output")

        assert res.ok is True
        assert (res.server, res.tool) == ("test_server", "test_tool")
        assert res.text == "output"
        assert res.structured is None

    def test_creates_result_with_structured_content(self):
        payload = {"key": "value"}
        res = MCPToolResult(server="test_server", tool="test_tool", structured=payload)

        assert res.structured == {"key": "value"}
        assert res.text is None
class TestParseCallResult:
    """Mapping of raw MCP call results onto ``MCPToolResult``."""

    @staticmethod
    def _raw_result(structured: Any, content: Any) -> MagicMock:
        """Return a mock call result carrying the given payload attributes."""
        raw = MagicMock()
        raw.structuredContent = structured
        raw.content = content
        return raw

    def test_parses_text_content(self):
        raw = self._raw_result(None, [MagicMock(text="Hello world")])

        parsed = _parse_call_result("server", "tool", raw)

        assert parsed.server == "server"
        assert parsed.tool == "tool"
        assert parsed.text == "Hello world"
        assert parsed.structured is None

    def test_parses_structured_content(self):
        raw = self._raw_result({"data": "value"}, None)

        parsed = _parse_call_result("server", "tool", raw)

        assert parsed.structured == {"data": "value"}
        assert parsed.text is None

    def test_prefers_structured_over_text(self):
        # When both payloads are present, only the structured one is kept.
        raw = self._raw_result({"data": "value"}, [MagicMock(text="text content")])

        parsed = _parse_call_result("server", "tool", raw)

        assert parsed.structured == {"data": "value"}
        assert parsed.text is None

    def test_joins_multiple_text_blocks(self):
        blocks = [MagicMock(text=line) for line in ("line1", "line2")]

        parsed = _parse_call_result("server", "tool", self._raw_result(None, blocks))

        assert parsed.text == "line1\nline2"
class TestMCPStderrCapture:
    """Tests for _mcp_stderr_capture and _stderr_logger_thread."""

    @staticmethod
    def _wait_for_debug_calls(
        debug_mock: MagicMock, *messages: str, timeout: float = 2.0
    ) -> None:
        """Poll *debug_mock* until every one of *messages* has been logged.

        The stderr lines are logged from a background thread, so a fixed
        ``time.sleep`` is a race on a loaded machine. This returns as soon as
        all messages are seen. On timeout it simply returns and leaves the
        failure report to the caller's mock assertions.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # Snapshot: the logger thread may append while we iterate.
            logged = list(debug_mock.call_args_list)
            if all(any(c.args[:1] == (msg,) for c in logged) for msg in messages):
                return
            time.sleep(0.01)

    @pytest.mark.asyncio
    async def test_mcp_stderr_capture_returns_writable_stream(self):
        async with _mcp_stderr_capture() as stream:
            assert stream is not None
            assert callable(getattr(stream, "write", None))
            stream.write("test\n")

    def test_stderr_logger_thread_logs_decoded_lines(self):
        r_fd, w_fd = os.pipe()
        try:
            vibe_logger = logging.getLogger("vibe")
            with patch.object(vibe_logger, "debug") as debug_mock:
                thread = threading.Thread(
                    target=_stderr_logger_thread, args=(r_fd,), daemon=True
                )
                thread.start()
                with os.fdopen(w_fd, "wb") as writer:
                    # The file object now owns the descriptor. Mark it so the
                    # outer cleanup does not close it a second time.
                    w_fd = -1
                    writer.write(b"hello stderr\n")
                    writer.write(b"second line\n")
                # Leaving the ``with`` closes the write end of the pipe.
                self._wait_for_debug_calls(
                    debug_mock, "[MCP stderr] hello stderr", "[MCP stderr] second line"
                )
                debug_mock.assert_any_call("[MCP stderr] hello stderr")
                debug_mock.assert_any_call("[MCP stderr] second line")
        finally:
            if w_fd >= 0:
                try:
                    os.close(w_fd)
                except OSError:
                    pass
            try:
                os.close(r_fd)
            except OSError:
                pass

    @pytest.mark.asyncio
    async def test_mcp_stderr_capture_logs_written_data(self):
        vibe_logger = logging.getLogger("vibe")
        with patch.object(vibe_logger, "debug") as debug_mock:
            async with _mcp_stderr_capture() as stream:
                stream.write("captured line\n")
            self._wait_for_debug_calls(debug_mock, "[MCP stderr] captured line")
            debug_mock.assert_called_with("[MCP stderr] captured line")

    @pytest.mark.asyncio
    async def test_mcp_stderr_capture_ignores_empty_lines(self):
        vibe_logger = logging.getLogger("vibe")
        with patch.object(vibe_logger, "debug") as debug_mock:
            async with _mcp_stderr_capture() as stream:
                stream.write("\n\n")
            # Nothing should be logged, so there is no event to wait for.
            # Give the logger thread a short window to (wrongly) emit something.
            time.sleep(0.05)
            debug_mock.assert_not_called()
class TestCreateMCPHttpProxyToolClass:
    """Naming, description, timeouts and schema of HTTP MCP proxy tool classes."""

    _URL = "http://localhost:8080"

    def test_creates_tool_class_with_correct_name(self):
        proxy = create_mcp_http_proxy_tool_class(
            url=self._URL,
            remote=RemoteTool(name="my_tool", description="Test tool"),
            alias="test_server",
        )
        assert proxy.get_name() == "test_server_my_tool"

    def test_creates_tool_class_with_url_based_alias(self):
        # Without an explicit alias, the prefix is derived from host and port.
        proxy = create_mcp_http_proxy_tool_class(
            url=self._URL, remote=RemoteTool(name="my_tool")
        )
        assert proxy.get_name() == "localhost_8080_my_tool"

    def test_includes_description_with_hint(self):
        proxy = create_mcp_http_proxy_tool_class(
            url=self._URL,
            remote=RemoteTool(name="my_tool", description="Base description"),
            alias="test",
            server_hint="Use this for testing",
        )
        description = proxy.description
        for fragment in ("[test]", "Base description", "Hint: Use this for testing"):
            assert fragment in description

    def test_stores_timeout_settings(self):
        proxy = cast(
            Any,
            create_mcp_http_proxy_tool_class(
                url=self._URL,
                remote=RemoteTool(name="my_tool"),
                startup_timeout_sec=30.0,
                tool_timeout_sec=120.0,
            ),
        )
        assert proxy._startup_timeout_sec == 30.0
        assert proxy._tool_timeout_sec == 120.0

    def test_returns_correct_parameters(self):
        schema = {"type": "object", "properties": {"arg": {"type": "string"}}}
        remote = RemoteTool.model_validate({"name": "my_tool", "inputSchema": schema})

        proxy = create_mcp_http_proxy_tool_class(url=self._URL, remote=remote)

        assert proxy.get_parameters() == {
            "type": "object",
            "properties": {"arg": {"type": "string"}},
        }
class TestCreateMCPStdioProxyToolClass:
    """Naming, env, timeouts and description of stdio MCP proxy tool classes."""

    @staticmethod
    def _argv() -> list[str]:
        # Fresh list per call so tests never share one mutable command list.
        return ["python", "-m", "mcp_server"]

    def test_creates_tool_class_with_alias(self):
        proxy = create_mcp_stdio_proxy_tool_class(
            command=self._argv(), remote=RemoteTool(name="my_tool"), alias="my_server"
        )
        assert proxy.get_name() == "my_server_my_tool"

    def test_creates_tool_class_with_command_based_alias(self):
        # The derived prefix is only pinned loosely: it starts with the
        # executable name, and the full name ends with the remote tool name.
        tool_name = create_mcp_stdio_proxy_tool_class(
            command=self._argv(), remote=RemoteTool(name="my_tool")
        ).get_name()
        assert tool_name.startswith("python_")
        assert tool_name.endswith("_my_tool")

    def test_stores_env_settings(self):
        proxy = cast(
            Any,
            create_mcp_stdio_proxy_tool_class(
                command=self._argv(),
                remote=RemoteTool(name="my_tool"),
                env={"API_KEY": "secret"},
            ),
        )
        assert proxy._env == {"API_KEY": "secret"}

    def test_stores_timeout_settings(self):
        proxy = cast(
            Any,
            create_mcp_stdio_proxy_tool_class(
                command=self._argv(),
                remote=RemoteTool(name="my_tool"),
                startup_timeout_sec=15.0,
                tool_timeout_sec=90.0,
            ),
        )
        assert proxy._startup_timeout_sec == 15.0
        assert proxy._tool_timeout_sec == 90.0

    def test_includes_hint_in_description(self):
        proxy = create_mcp_stdio_proxy_tool_class(
            command=["python"],
            remote=RemoteTool(name="my_tool", description="Base description"),
            alias="test",
            server_hint="For testing only",
        )
        assert "Hint: For testing only" in proxy.description
class TestMCPConfigModels:
    """Defaults, validation and helpers of the MCP server config models."""

    @staticmethod
    def _stdio(command: str | list[str], **extra: Any) -> MCPStdio:
        """Build an ``MCPStdio`` named ``test`` running *command*."""
        return MCPStdio(name="test", transport="stdio", command=command, **extra)

    def test_mcp_base_default_timeouts(self):
        cfg = self._stdio("python -m test_server")
        assert (cfg.startup_timeout_sec, cfg.tool_timeout_sec) == (10.0, 60.0)

    def test_mcp_base_custom_timeouts(self):
        cfg = self._stdio(
            "python -m test_server", startup_timeout_sec=30.0, tool_timeout_sec=120.0
        )
        assert cfg.startup_timeout_sec == 30.0
        assert cfg.tool_timeout_sec == 120.0

    def test_mcp_base_rejects_non_positive_timeout(self):
        with pytest.raises(ValidationError):
            self._stdio("python", startup_timeout_sec=0)

    def test_mcp_stdio_with_env(self):
        cfg = self._stdio("python -m server", env={"API_KEY": "secret", "DEBUG": "1"})
        assert cfg.env == {"API_KEY": "secret", "DEBUG": "1"}

    def test_mcp_stdio_argv_with_string_command(self):
        # A string command is split into argv tokens.
        cfg = self._stdio("python -m server --port 8080")
        assert cfg.argv() == ["python", "-m", "server", "--port", "8080"]

    def test_mcp_stdio_argv_with_list_command(self):
        # A list command is followed by the separate ``args`` entries.
        cfg = self._stdio(["python", "-m", "server"], args=["--port", "8080"])
        assert cfg.argv() == ["python", "-m", "server", "--port", "8080"]

    def test_mcp_http_default_timeouts(self):
        cfg = MCPHttp(name="test", transport="http", url="http://localhost:8080")
        assert (cfg.startup_timeout_sec, cfg.tool_timeout_sec) == (10.0, 60.0)

    def test_mcp_streamable_http_default_timeouts(self):
        cfg = MCPStreamableHttp(
            name="test", transport="streamable-http", url="http://localhost:8080"
        )
        assert (cfg.startup_timeout_sec, cfg.tool_timeout_sec) == (10.0, 60.0)

    def test_mcp_name_normalization(self):
        cfg = MCPStdio(name="my server!@#$%", transport="stdio", command="python")
        # Invalid characters become underscores; the trailing run is stripped.
        assert cfg.name == "my_server"
class TestMCPRegistry:
    """Keying, caching and discovery behaviour of ``MCPRegistry``.

    Most tests seed ``registry._cache`` directly. This lets ``get_tools`` be
    exercised without a real server. The discovery tests instead patch the
    ``list_tools_*`` names looked up in ``vibe.core.tools.mcp.registry``.
    """

    def _make_http_server(
        self, name: str, url: str = "http://localhost:8080"
    ) -> MCPHttp:
        return MCPHttp(name=name, transport="http", url=url)

    def _make_stdio_server(self, name: str, command: str = "python -m srv") -> MCPStdio:
        return MCPStdio(name=name, transport="stdio", command=command)

    def test_server_key_is_stable(self):
        srv = self._make_http_server("s1")
        registry = MCPRegistry()
        assert registry._server_key(srv) == registry._server_key(srv)

    def test_different_configs_produce_different_keys(self):
        registry = MCPRegistry()
        s1 = self._make_http_server("s1", url="http://a:1")
        s2 = self._make_http_server("s2", url="http://b:2")
        assert registry._server_key(s1) != registry._server_key(s2)

    def test_get_tools_caches_discovery(self):
        registry = MCPRegistry()
        srv = self._make_http_server("cached")
        remote = RemoteTool(name="tool_a", description="A tool")
        proxy = create_mcp_http_proxy_tool_class(
            url="http://localhost:8080", remote=remote, alias="cached"
        )
        key = registry._server_key(srv)
        registry._cache[key] = {proxy.get_name(): proxy}
        tools = registry.get_tools([srv])
        assert "cached_tool_a" in tools
        assert tools["cached_tool_a"] is proxy

    def test_get_tools_returns_empty_for_no_servers(self):
        registry = MCPRegistry()
        assert registry.get_tools([]) == {}

    def test_clear_drops_cache(self):
        registry = MCPRegistry()
        srv = self._make_http_server("s")
        proxy = create_mcp_http_proxy_tool_class(
            url="http://localhost:8080", remote=RemoteTool(name="t"), alias="s"
        )
        key = registry._server_key(srv)
        registry._cache[key] = {proxy.get_name(): proxy}
        registry.clear()
        assert len(registry._cache) == 0

    def test_count_loaded_excludes_failed_servers(self):
        registry = MCPRegistry()
        ok_srv = self._make_http_server("ok", url="http://ok:1")
        fail_srv = self._make_http_server("fail", url="http://fail:2")
        proxy = create_mcp_http_proxy_tool_class(
            url="http://ok:1", remote=RemoteTool(name="t"), alias="ok"
        )
        # Only ``ok_srv`` has a cache entry; ``fail_srv`` stands in for a
        # server whose discovery failed.
        registry._cache[registry._server_key(ok_srv)] = {proxy.get_name(): proxy}
        assert registry.count_loaded([ok_srv, fail_srv]) == 1
        assert registry.count_loaded([ok_srv]) == 1
        assert registry.count_loaded([fail_srv]) == 0
        assert registry.count_loaded([]) == 0

    def test_cache_survives_multiple_get_tools_calls(self):
        registry = MCPRegistry()
        srv = self._make_http_server("stable")
        remote = RemoteTool(name="t1")
        proxy = create_mcp_http_proxy_tool_class(
            url="http://localhost:8080", remote=remote, alias="stable"
        )
        key = registry._server_key(srv)
        registry._cache[key] = {proxy.get_name(): proxy}
        first = registry.get_tools([srv])
        second = registry.get_tools([srv])
        assert first == second
        assert first["stable_t1"] is second["stable_t1"]

    def test_disjoint_server_lists_across_agents(self):
        registry = MCPRegistry()
        srv_x = self._make_http_server("x", url="http://x:1")
        srv_y = self._make_http_server("y", url="http://y:2")
        proxy_x = create_mcp_http_proxy_tool_class(
            url="http://x:1", remote=RemoteTool(name="tx"), alias="x"
        )
        proxy_y = create_mcp_http_proxy_tool_class(
            url="http://y:2", remote=RemoteTool(name="ty"), alias="y"
        )
        registry._cache[registry._server_key(srv_x)] = {proxy_x.get_name(): proxy_x}
        registry._cache[registry._server_key(srv_y)] = {proxy_y.get_name(): proxy_y}
        # Each agent asks only for its own servers and must not see the other's.
        agent_a_tools = registry.get_tools([srv_x])
        agent_b_tools = registry.get_tools([srv_y])
        assert "x_tx" in agent_a_tools
        assert "y_ty" not in agent_a_tools
        assert "y_ty" in agent_b_tools
        assert "x_tx" not in agent_b_tools

    @pytest.mark.asyncio
    async def test_discover_http_success(self):
        registry = MCPRegistry()
        srv = self._make_http_server("demo", url="http://demo:9090")
        remote = RemoteTool(name="hello", description="Hi")
        with patch(
            "vibe.core.tools.mcp.registry.list_tools_http", return_value=[remote]
        ):
            tools = await registry._discover_http(srv)
        assert tools is not None
        assert len(tools) == 1
        name = next(iter(tools))
        assert name == "demo_hello"

    @pytest.mark.asyncio
    async def test_discover_http_failure_returns_none(self):
        registry = MCPRegistry()
        srv = self._make_http_server("fail", url="http://fail:1")
        with patch(
            "vibe.core.tools.mcp.registry.list_tools_http",
            side_effect=ConnectionError("down"),
        ):
            tools = await registry._discover_http(srv)
        assert tools is None

    @pytest.mark.asyncio
    async def test_discover_stdio_success(self):
        registry = MCPRegistry()
        srv = self._make_stdio_server("local", command="python -m local_srv")
        remote = RemoteTool(name="run", description="Run it")
        with patch(
            "vibe.core.tools.mcp.registry.list_tools_stdio", return_value=[remote]
        ):
            tools = await registry._discover_stdio(srv)
        assert tools is not None
        assert len(tools) == 1
        name = next(iter(tools))
        assert name == "local_run"

    @pytest.mark.asyncio
    async def test_discover_stdio_failure_returns_none(self):
        registry = MCPRegistry()
        srv = self._make_stdio_server("broken")
        with patch(
            "vibe.core.tools.mcp.registry.list_tools_stdio",
            side_effect=OSError("no binary"),
        ):
            tools = await registry._discover_stdio(srv)
        assert tools is None

    def test_get_tools_discovers_only_uncached(self):
        registry = MCPRegistry()
        cached_srv = self._make_http_server("cached", url="http://c:1")
        new_srv = self._make_http_server("new", url="http://n:2")
        cached_proxy = create_mcp_http_proxy_tool_class(
            url="http://c:1", remote=RemoteTool(name="ct"), alias="cached"
        )
        registry._cache[registry._server_key(cached_srv)] = {
            cached_proxy.get_name(): cached_proxy
        }
        new_remote = RemoteTool(name="nt")
        with patch(
            "vibe.core.tools.mcp.registry.list_tools_http", return_value=[new_remote]
        ) as mock_list:
            tools = registry.get_tools([cached_srv, new_srv])
        # Only the uncached server may trigger discovery. The cached one must
        # be served from ``_cache`` without a second list_tools_http call.
        mock_list.assert_called_once()
        assert "cached_ct" in tools
        assert "new_nt" in tools
        assert len(registry._cache) == 2
class TestMCPStdioCwd:
    """Propagation of the optional ``cwd`` setting for stdio MCP servers."""

    @staticmethod
    def _wire_stdio_mocks(
        mock_client: MagicMock, mock_session_cls: MagicMock
    ) -> MagicMock:
        """Make patched ``stdio_client`` / ``ClientSession`` usable as async CMs.

        ``stdio_client(...)`` enters to a 2-tuple of dummy objects.
        ``ClientSession(...)`` enters to the returned session mock, whose
        ``initialize`` is awaitable. Callers attach whichever session method
        their test drives (``list_tools``, ``call_tool``, ...).
        """
        mock_client.return_value.__aenter__ = AsyncMock(
            return_value=(MagicMock(), MagicMock())
        )
        mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
        session = MagicMock()
        session.initialize = AsyncMock()
        mock_session_cls.return_value.__aenter__ = AsyncMock(return_value=session)
        mock_session_cls.return_value.__aexit__ = AsyncMock(return_value=False)
        return session

    def test_mcp_stdio_cwd_defaults_to_none(self):
        config = MCPStdio(name="test", transport="stdio", command="python -m srv")
        assert config.cwd is None

    def test_mcp_stdio_cwd_accepts_string(self):
        config = MCPStdio(
            name="test",
            transport="stdio",
            command="python -m srv",
            cwd="/tmp/myproject",
        )
        assert config.cwd == "/tmp/myproject"

    @pytest.mark.asyncio
    async def test_list_tools_stdio_passes_cwd_to_params(self):
        with (
            patch("vibe.core.tools.mcp.tools.stdio_client") as mock_client,
            patch("vibe.core.tools.mcp.tools.ClientSession") as mock_session_cls,
            patch("vibe.core.tools.mcp.tools.StdioServerParameters") as mock_params_cls,
        ):
            session = self._wire_stdio_mocks(mock_client, mock_session_cls)
            session.list_tools = AsyncMock(return_value=MagicMock(tools=[]))
            await list_tools_stdio(["python", "-m", "srv"], cwd="/tmp/myproject")
            mock_params_cls.assert_called_once_with(
                command="python", args=["-m", "srv"], env=None, cwd="/tmp/myproject"
            )

    @pytest.mark.asyncio
    async def test_call_tool_stdio_passes_cwd_to_params(self):
        with (
            patch("vibe.core.tools.mcp.tools.stdio_client") as mock_client,
            patch("vibe.core.tools.mcp.tools.ClientSession") as mock_session_cls,
            patch("vibe.core.tools.mcp.tools.StdioServerParameters") as mock_params_cls,
            patch("vibe.core.tools.mcp.tools._parse_call_result") as mock_parse,
        ):
            session = self._wire_stdio_mocks(mock_client, mock_session_cls)
            session.call_tool = AsyncMock(return_value=MagicMock())
            mock_parse.return_value = MagicMock(spec=MCPToolResult)
            await call_tool_stdio(
                ["python", "-m", "srv"], "my_tool", {}, cwd="/tmp/myproject"
            )
            mock_params_cls.assert_called_once_with(
                command="python", args=["-m", "srv"], env=None, cwd="/tmp/myproject"
            )

    @pytest.mark.asyncio
    async def test_discover_stdio_passes_cwd_to_list_tools(self):
        registry = MCPRegistry()
        srv = MCPStdio(
            name="local",
            transport="stdio",
            command="python -m srv",
            cwd="/tmp/myproject",
        )
        remote = RemoteTool(name="run", description="Run it")
        with patch(
            "vibe.core.tools.mcp.registry.list_tools_stdio", return_value=[remote]
        ) as mock_list:
            await registry._discover_stdio(srv)
            mock_list.assert_called_once_with(
                ["python", "-m", "srv"],
                env=None,
                cwd="/tmp/myproject",
                startup_timeout_sec=srv.startup_timeout_sec,
            )

    @pytest.mark.asyncio
    async def test_discover_stdio_passes_cwd_to_proxy_class(self):
        registry = MCPRegistry()
        srv = MCPStdio(
            name="local",
            transport="stdio",
            command="python -m srv",
            cwd="/tmp/myproject",
        )
        remote = RemoteTool(name="run", description="Run it")
        with (
            patch(
                "vibe.core.tools.mcp.registry.list_tools_stdio", return_value=[remote]
            ),
            # ``wraps`` keeps the real factory behaviour while recording kwargs.
            patch(
                "vibe.core.tools.mcp.registry.create_mcp_stdio_proxy_tool_class",
                wraps=create_mcp_stdio_proxy_tool_class,
            ) as mock_create,
        ):
            await registry._discover_stdio(srv)
            assert mock_create.call_args.kwargs["cwd"] == "/tmp/myproject"

    def test_proxy_tool_stores_cwd(self):
        remote = RemoteTool(name="run")
        proxy_cls = cast(
            Any,
            create_mcp_stdio_proxy_tool_class(
                command=["python", "-m", "srv"], remote=remote, cwd="/tmp/myproject"
            ),
        )
        assert proxy_cls._cwd == "/tmp/myproject"

    def test_proxy_tool_cwd_defaults_to_none(self):
        remote = RemoteTool(name="run")
        proxy_cls = cast(
            Any,
            create_mcp_stdio_proxy_tool_class(
                command=["python", "-m", "srv"], remote=remote
            ),
        )
        assert proxy_cls._cwd is None