mirror of
https://github.com/browser-use/browser-use
synced 2026-04-22 17:45:09 +02:00
549 lines
17 KiB
Python
549 lines
17 KiB
Python
"""MCP (Model Context Protocol) client integration for browser-use.

This module provides integration between external MCP servers and browser-use's action registry.
MCP tools are dynamically discovered and registered as browser-use actions.

Example usage:
    from browser_use import Tools
    from browser_use.mcp.client import MCPClient

    tools = Tools()

    # Connect to an MCP server
    mcp_client = MCPClient(
        server_name="my-server",
        command="npx",
        args=["@mycompany/mcp-server@latest"]
    )

    # Register all MCP tools as browser-use actions
    await mcp_client.register_to_tools(tools)

    # Now use with Agent as normal - MCP tools are available as actions
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
|
|
from browser_use.agent.views import ActionResult
|
|
from browser_use.telemetry import MCPClientTelemetryEvent, ProductTelemetry
|
|
from browser_use.tools.registry.service import Registry
|
|
from browser_use.tools.service import Tools
|
|
from browser_use.utils import create_task_with_error_handling, get_browser_use_version
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
# Import MCP SDK
|
|
from mcp import ClientSession, StdioServerParameters, types
|
|
from mcp.client.stdio import stdio_client
|
|
|
|
# The MCP SDK imports above are unconditional, so reaching this line means they
# succeeded; the flag is therefore always True when this module imports cleanly.
MCP_AVAILABLE = True
|
|
|
|
|
|
class MCPClient:
|
|
"""Client for connecting to MCP servers and exposing their tools as browser-use actions."""
|
|
|
|
def __init__(
    self,
    server_name: str,
    command: str,
    args: list[str] | None = None,
    env: dict[str, str] | None = None,
):
    """Create an unconnected client for a single MCP server.

    Nothing is launched here; the server process is only started by ``connect()``.

    Args:
        server_name: Label for the MCP server, used in log messages and telemetry
        command: Executable that starts the MCP server (e.g., "npx", "python")
        args: Command-line arguments for ``command`` (e.g., ["@playwright/mcp@latest"]);
            a missing or empty value is stored as an empty list
        env: Environment variables handed to the server process
    """
    # --- launch configuration ---
    self.server_name = server_name
    self.command = command
    self.args = args if args else []
    self.env = env

    # --- live connection handles (filled in by the background stdio task) ---
    self.session: ClientSession | None = None
    self._stdio_task = None
    self._read_stream = None
    self._write_stream = None

    # --- discovery / registration bookkeeping ---
    self._tools: dict[str, types.Tool] = {}
    self._registered_actions: set[str] = set()

    # --- lifecycle flags ---
    self._connected = False
    # Set by disconnect() to release the background task from its wait.
    self._disconnect_event = asyncio.Event()

    self._telemetry = ProductTelemetry()
|
|
|
|
async def connect(self) -> None:
    """Connect to the MCP server and discover available tools.

    Starts the stdio client in a background task and polls (every 0.1s, up to
    roughly 10 seconds) until that task reports the connection as established.
    Calling this while already connected is a no-op.

    A 'connect' telemetry event is captured whether the attempt succeeds or fails.

    Raises:
        RuntimeError: If the background task exits before connecting, or the
            connection is not established within the polling window.
    """
    if self._connected:
        logger.debug(f'Already connected to {self.server_name}')
        return

    start_time = time.time()
    error_msg = None

    try:
        logger.info(f"🔌 Connecting to MCP server '{self.server_name}': {self.command} {' '.join(self.args)}")

        # disconnect() leaves this event set. Without clearing it, the new background
        # task would fall straight through its wait and tear the session down again,
        # making a reconnect after disconnect() impossible.
        self._disconnect_event.clear()

        # Create server parameters
        server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env)

        # Start stdio client in background task
        self._stdio_task = create_task_with_error_handling(
            self._run_stdio_client(server_params), name='mcp_stdio_client', suppress_exceptions=True
        )

        # Wait for connection to be established
        retries = 0
        max_retries = 100  # 10 second timeout (increased for parallel test execution)
        while not self._connected and retries < max_retries:
            # The task object is awaited and cancelled like an asyncio task elsewhere
            # in this class; if it has already finished without connecting, there is
            # nothing left to wait for.
            if self._stdio_task.done():
                break
            await asyncio.sleep(0.1)
            retries += 1

        if not self._connected:
            if self._stdio_task.done():
                error_msg = f"MCP server '{self.server_name}' connection task exited before the connection was established"
            else:
                error_msg = f"Failed to connect to MCP server '{self.server_name}' after {max_retries * 0.1} seconds"
                # Do not leave a half-started server process running in the background.
                self._stdio_task.cancel()
            raise RuntimeError(error_msg)

        logger.info(f"📦 Discovered {len(self._tools)} tools from '{self.server_name}': {list(self._tools.keys())}")

    except Exception as e:
        error_msg = str(e)
        raise
    finally:
        # Capture telemetry for connect action
        duration = time.time() - start_time
        self._telemetry.capture(
            MCPClientTelemetryEvent(
                server_name=self.server_name,
                command=self.command,
                tools_discovered=len(self._tools),
                version=get_browser_use_version(),
                action='connect',
                duration_seconds=duration,
                error_message=error_msg,
            )
        )
|
|
|
|
async def _run_stdio_client(self, server_params: StdioServerParameters):
    """Own the stdio connection for its whole lifetime (runs as a background task).

    Opens the stdio transport and an MCP session on top of it, lists the server's
    tools, flags the client as connected, then parks until the disconnect event is
    set. On any exit the connected flag and session reference are reset.
    """
    try:
        async with stdio_client(server_params) as streams:
            reader, writer = streams
            self._read_stream = reader
            self._write_stream = writer

            async with ClientSession(reader, writer) as mcp_session:
                self.session = mcp_session

                # Handshake with the server
                await mcp_session.initialize()

                # Index the advertised tools by name
                listing = await mcp_session.list_tools()
                discovered = {}
                for entry in listing.tools:
                    discovered[entry.name] = entry
                self._tools = discovered

                # connect() polls this flag
                self._connected = True

                # Hold both context managers open until disconnect() fires the event
                await self._disconnect_event.wait()

    except Exception as e:
        # The connected flag is reset in the finally clause below.
        logger.error(f'MCP server connection error: {e}')
        raise
    finally:
        self._connected = False
        self.session = None
|
|
|
|
async def disconnect(self) -> None:
    """Disconnect from the MCP server.

    Signals the background stdio task to shut down and waits up to 2 seconds for
    it; if it does not finish in time it is cancelled. Errors raised while waiting
    are logged, not propagated. The cached tool list and the set of registered
    action names are always cleared, and a 'disconnect' telemetry event is
    captured and flushed. Does nothing when the client is not connected.
    """
    if not self._connected:
        return

    start_time = time.time()
    error_msg = None

    try:
        logger.info(f"🔌 Disconnecting from MCP server '{self.server_name}'")

        # Signal disconnect
        self._connected = False
        self._disconnect_event.set()

        # Wait for stdio task to finish
        stdio_task = self._stdio_task
        if stdio_task:
            try:
                await asyncio.wait_for(stdio_task, timeout=2.0)
            # asyncio.TimeoutError is what wait_for raises on every supported Python;
            # it is only an alias of the builtin TimeoutError from 3.11 onwards.
            except asyncio.TimeoutError:
                logger.warning(f"Timeout waiting for MCP server '{self.server_name}' to disconnect")
                stdio_task.cancel()
                try:
                    await stdio_task
                except asyncio.CancelledError:
                    pass

    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error disconnecting from MCP server: {e}')
    finally:
        # Reset cached state even when waiting on the task failed, so a disconnected
        # client never keeps advertising tools from the old session.
        self._tools.clear()
        self._registered_actions.clear()

        # Capture telemetry for disconnect action
        duration = time.time() - start_time
        self._telemetry.capture(
            MCPClientTelemetryEvent(
                server_name=self.server_name,
                command=self.command,
                tools_discovered=0,  # Tools cleared on disconnect
                version=get_browser_use_version(),
                action='disconnect',
                duration_seconds=duration,
                error_message=error_msg,
            )
        )
        self._telemetry.flush()
|
|
|
|
async def register_to_tools(
    self,
    tools: Tools,
    tool_filter: list[str] | None = None,
    prefix: str | None = None,
) -> None:
    """Expose the server's discovered tools as actions on a browser-use ``Tools`` object.

    Connects first if the client is not connected yet. Action names already
    registered by this client are left untouched.

    Args:
        tools: Browser-use tools whose registry receives the actions
        tool_filter: Names of the MCP tools to register; a missing or empty
            value registers every discovered tool
        prefix: Optional string prepended to each action name (e.g., "playwright_")
    """
    if not self._connected:
        await self.connect()

    registry = tools.registry

    for tool_name, tool in self._tools.items():
        selected = (not tool_filter) or (tool_name in tool_filter)
        if not selected:
            continue

        action_name = tool_name if not prefix else f'{prefix}{tool_name}'

        # Only register names this client has not handled before
        if action_name not in self._registered_actions:
            self._register_tool_as_action(registry, action_name, tool)
            self._registered_actions.add(action_name)

    logger.info(f"✅ Registered {len(self._registered_actions)} MCP tools from '{self.server_name}' as browser-use actions")
|
|
|
|
def _register_tool_as_action(self, registry: Registry, action_name: str, tool: Any) -> None:
    """Register a single MCP tool as a browser-use action.

    Builds a pydantic parameter model from the tool's JSON Schema (when it declares
    any properties), wraps the tool call in an async function and registers that
    function through ``registry.action``.

    Args:
        registry: Browser-use registry to register action to
        action_name: Name for the registered action
        tool: MCP Tool object with schema information
    """
    # Translate the tool's JSON Schema properties into pydantic field definitions
    param_fields = {}

    if tool.inputSchema:
        # MCP tools use JSON Schema for parameters
        properties = tool.inputSchema.get('properties', {})
        required = set(tool.inputSchema.get('required', []))

        for param_name, param_schema in properties.items():
            # Convert JSON Schema type to Python type
            param_type = self._json_schema_to_python_type(param_schema, f'{action_name}_{param_name}')

            if param_name in required:
                default = ...  # Required field
            else:
                # Optional field - make type optional; default is the schema default or None
                param_type = param_type | None
                default = param_schema.get('default')

            # Add field with description if available
            field_kwargs = {}
            if 'description' in param_schema:
                field_kwargs['description'] = param_schema['description']

            param_fields[param_name] = (param_type, Field(default, **field_kwargs))

    # Create Pydantic model for the tool parameters
    if param_fields:
        # Base class carrying the shared model configuration
        class ConfiguredBaseModel(BaseModel):
            model_config = ConfigDict(extra='forbid', validate_by_name=True, validate_by_alias=True)

        param_model = create_model(f'{action_name}_Params', __base__=ConfiguredBaseModel, **param_fields)
    else:
        # No parameters - no model
        param_model = None

    # Set up action filters
    domains = None
    # Note: page_filter has been removed since we no longer use Page objects
    # Browser tools filtering would need to be done via domain filters instead

    async def invoke_tool(params: Any) -> ActionResult:
        """Call the MCP tool once and convert the outcome into an ActionResult.

        ``params`` is the validated parameter model, or None for a tool without parameters.
        """
        if not self.session or not self._connected:
            return ActionResult(error=f"MCP server '{self.server_name}' not connected", success=False)

        if params is None:
            tool_params = {}
            logger.debug(f"🔧 Calling MCP tool '{tool.name}' with no params")
        else:
            # Convert pydantic model to dict for MCP call
            tool_params = params.model_dump(exclude_none=True)
            logger.debug(f"🔧 Calling MCP tool '{tool.name}' with params: {tool_params}")

        start_time = time.time()
        error_msg = None

        try:
            # Call the MCP tool
            result = await self.session.call_tool(tool.name, tool_params)

            # Convert MCP result to text
            extracted_content = self._format_mcp_result(result)

            # MCP servers report tool-level failures inside the result via isError
            # instead of raising; surface those as errors rather than as successes.
            if getattr(result, 'isError', False) is True:
                error_msg = f"MCP tool '{tool.name}' returned an error: {extracted_content}"
                logger.error(error_msg)
                return ActionResult(error=error_msg, success=False)

            return ActionResult(
                extracted_content=extracted_content,
                long_term_memory=f"Used MCP tool '{tool.name}' from {self.server_name}",
                include_extracted_content_only_once=True,
            )

        except Exception as e:
            error_msg = f"MCP tool '{tool.name}' failed: {str(e)}"
            logger.error(error_msg)
            return ActionResult(error=error_msg, success=False)
        finally:
            # Capture telemetry for tool call
            duration = time.time() - start_time
            self._telemetry.capture(
                MCPClientTelemetryEvent(
                    server_name=self.server_name,
                    command=self.command,
                    tools_discovered=len(self._tools),
                    version=get_browser_use_version(),
                    action='tool_call',
                    tool_name=tool.name,
                    duration_seconds=duration,
                    error_message=error_msg,
                )
            )

    # Create async wrapper function for the MCP tool
    # Need to define function with explicit parameters to satisfy registry validation
    if param_model:
        # Type 1: Function takes param model as first parameter
        async def mcp_action_wrapper(params: param_model) -> ActionResult:  # type: ignore[no-redef]
            """Wrapper function that calls the MCP tool."""
            return await invoke_tool(params)

    else:
        # No parameters - empty function signature
        async def mcp_action_wrapper() -> ActionResult:  # type: ignore[no-redef]
            """Wrapper function that calls the MCP tool."""
            return await invoke_tool(None)

    # Set function metadata for better debugging
    mcp_action_wrapper.__name__ = action_name
    mcp_action_wrapper.__qualname__ = f'mcp.{self.server_name}.{action_name}'

    # Register the action with browser-use
    description = tool.description or f'MCP tool from {self.server_name}: {tool.name}'

    # Use the registry's action decorator
    registry.action(description=description, param_model=param_model, domains=domains)(mcp_action_wrapper)

    logger.debug(f"✅ Registered MCP tool '{tool.name}' as action '{action_name}'")
|
|
|
|
def _format_mcp_result(self, result: Any) -> str:
|
|
"""Format MCP tool result into a string for ActionResult.
|
|
|
|
Args:
|
|
result: Raw result from MCP tool call
|
|
|
|
Returns:
|
|
Formatted string representation of the result
|
|
"""
|
|
# Handle different MCP result formats
|
|
if hasattr(result, 'content'):
|
|
# Structured content response
|
|
if isinstance(result.content, list):
|
|
# Multiple content items
|
|
parts = []
|
|
for item in result.content:
|
|
if hasattr(item, 'text'):
|
|
parts.append(item.text)
|
|
elif hasattr(item, 'type') and item.type == 'text':
|
|
parts.append(str(item))
|
|
else:
|
|
parts.append(str(item))
|
|
return '\n'.join(parts)
|
|
else:
|
|
return str(result.content)
|
|
elif isinstance(result, list):
|
|
# List of content items
|
|
parts = []
|
|
for item in result:
|
|
if hasattr(item, 'text'):
|
|
parts.append(item.text)
|
|
else:
|
|
parts.append(str(item))
|
|
return '\n'.join(parts)
|
|
else:
|
|
# Direct result or unknown format
|
|
return str(result)
|
|
|
|
def _json_schema_to_python_type(self, schema: dict, model_name: str = 'NestedModel') -> Any:
|
|
"""Convert JSON Schema type to Python type.
|
|
|
|
Args:
|
|
schema: JSON Schema definition
|
|
model_name: Name for nested models
|
|
|
|
Returns:
|
|
Python type corresponding to the schema
|
|
"""
|
|
json_type = schema.get('type', 'string')
|
|
|
|
# Basic type mapping
|
|
type_mapping = {
|
|
'string': str,
|
|
'number': float,
|
|
'integer': int,
|
|
'boolean': bool,
|
|
'array': list,
|
|
'null': type(None),
|
|
}
|
|
|
|
# Handle enums (they're still strings)
|
|
if 'enum' in schema:
|
|
return str
|
|
|
|
# Handle objects with nested properties
|
|
if json_type == 'object':
|
|
properties = schema.get('properties', {})
|
|
if properties:
|
|
# Create nested pydantic model for objects with properties
|
|
nested_fields = {}
|
|
required_fields = set(schema.get('required', []))
|
|
|
|
for prop_name, prop_schema in properties.items():
|
|
# Recursively process nested properties
|
|
prop_type = self._json_schema_to_python_type(prop_schema, f'{model_name}_{prop_name}')
|
|
|
|
# Determine if field is required and handle defaults
|
|
if prop_name in required_fields:
|
|
default = ... # Required field
|
|
else:
|
|
# Optional field - make type optional and handle default
|
|
prop_type = prop_type | None
|
|
if 'default' in prop_schema:
|
|
default = prop_schema['default']
|
|
else:
|
|
default = None
|
|
|
|
# Add field with description if available
|
|
field_kwargs = {}
|
|
if 'description' in prop_schema:
|
|
field_kwargs['description'] = prop_schema['description']
|
|
|
|
nested_fields[prop_name] = (prop_type, Field(default, **field_kwargs))
|
|
|
|
# Create a BaseModel class with proper configuration
|
|
class ConfiguredBaseModel(BaseModel):
|
|
model_config = ConfigDict(extra='forbid', validate_by_name=True, validate_by_alias=True)
|
|
|
|
try:
|
|
# Create and return nested pydantic model
|
|
return create_model(model_name, __base__=ConfiguredBaseModel, **nested_fields)
|
|
except Exception as e:
|
|
logger.error(f'Failed to create nested model {model_name}: {e}')
|
|
logger.debug(f'Fields: {nested_fields}')
|
|
# Fallback to basic dict if model creation fails
|
|
return dict
|
|
else:
|
|
# Object without properties - just return dict
|
|
return dict
|
|
|
|
# Handle arrays with specific item types
|
|
if json_type == 'array':
|
|
if 'items' in schema:
|
|
# Get the item type recursively
|
|
item_type = self._json_schema_to_python_type(schema['items'], f'{model_name}_item')
|
|
# Return properly typed list
|
|
return list[item_type]
|
|
else:
|
|
# Array without item type specification
|
|
return list
|
|
|
|
# Get base type for non-object types
|
|
base_type = type_mapping.get(json_type, str)
|
|
|
|
# Handle nullable/optional types
|
|
if schema.get('nullable', False) or json_type == 'null':
|
|
return base_type | None
|
|
|
|
return base_type
|
|
|
|
async def __aenter__(self):
    """Async context manager entry.

    Connects to the MCP server via ``connect()`` and returns this client, so it
    can be bound with ``async with MCPClient(...) as client``.
    """
    await self.connect()
    return self
|
|
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async context manager exit.

    Disconnects from the MCP server via ``disconnect()``. The exception details
    are not inspected and nothing is returned, so an exception raised inside the
    ``async with`` body is not suppressed.
    """
    await self.disconnect()
|