# mirror of https://github.com/browser-use/browser-use
# synced 2026-05-06 17:52:15 +02:00
# 244 lines, 7.2 KiB, Python
import logging
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
from browser_use.agent.prompts import AgentMessagePrompt
|
|
from browser_use.agent.service import Agent
|
|
from browser_use.browser.views import BrowserStateSummary, TabInfo
|
|
from browser_use.dom.views import DOMSelectorMap, EnhancedDOMTreeNode, NodeType, SerializedDOMState, SimplifiedNode
|
|
from browser_use.filesystem.file_system import FileSystem
|
|
from browser_use.llm.anthropic.chat import ChatAnthropic
|
|
from browser_use.llm.azure.chat import ChatAzureOpenAI
|
|
from browser_use.llm.base import BaseChatModel
|
|
from browser_use.llm.google.chat import ChatGoogle
|
|
from browser_use.llm.groq.chat import ChatGroq
|
|
|
|
# Optional OCI import
|
|
try:
|
|
from browser_use.llm.oci_raw.chat import ChatOCIRaw
|
|
|
|
OCI_AVAILABLE = True
|
|
except ImportError:
|
|
ChatOCIRaw = None
|
|
OCI_AVAILABLE = False
|
|
from browser_use.llm.openai.chat import ChatOpenAI
|
|
|
|
# Set logging level to INFO for this module
|
|
logger = logging.getLogger(__name__)
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
def _check_oci_credentials() -> bool:
	"""Return True when the optional OCI SDK is importable and a DEFAULT config profile loads."""
	if not OCI_AVAILABLE:
		return False
	try:
		import oci

		# Loading the profile is enough to prove credentials are configured.
		oci.config.from_file('~/.oci/config', 'DEFAULT')
	except Exception:
		# Any failure (missing file, bad profile, import error) means "not available".
		return False
	return True
|
|
|
|
|
|
def create_mock_state_message(temp_dir: str):
	"""Build an ``AgentMessagePrompt`` user message from a minimal fake browser state.

	The fabricated state contains exactly one clickable ``<button>`` element so
	that LLM responses are trivial to validate. Files created by the ``FileSystem``
	live under ``temp_dir`` and are cleaned up by the caller.
	"""
	target = 'ABCD1234ABCD1234ABCD1234ABCD1234ABCD1234'

	# The single interactive element: a visible <button id="test-button">.
	button_node = EnhancedDOMTreeNode(
		node_id=1,
		backend_node_id=1,
		node_type=NodeType.ELEMENT_NODE,
		node_name='button',
		node_value='Click Me',
		attributes={'id': 'test-button'},
		is_scrollable=False,
		is_visible=True,
		absolute_position=None,
		session_id=None,
		target_id=target,
		frame_id=None,
		content_document=None,
		shadow_root_type=None,
		shadow_roots=None,
		parent_node=None,
		children_nodes=None,
		ax_node=None,
		snapshot_node=None,
	)

	# Selector map is keyed by backend_node_id, matching how the DOM service builds it.
	element_map: DOMSelectorMap = {button_node.backend_node_id: button_node}

	# Tab whose target_id matches the element's, so the state is self-consistent.
	tab = TabInfo(
		target_id=target,
		url='https://example.com',
		title='Test Page',
	)

	serialized_dom = SerializedDOMState(
		_root=SimplifiedNode(
			original_node=button_node,
			children=[],
			should_display=True,
			is_interactive=True,
		),
		selector_map=element_map,
	)

	browser_summary = BrowserStateSummary(
		dom_state=serialized_dom,
		url='https://example.com',
		title='Test Page',
		tabs=[tab],
		screenshot='',  # no image — the message is requested with use_vision=False below
		pixels_above=0,
		pixels_below=0,
	)

	# Real FileSystem instance rooted at the caller-provided temp directory.
	fs = FileSystem(temp_dir)

	prompt = AgentMessagePrompt(
		browser_state_summary=browser_summary,
		file_system=fs,
		agent_history_description='',  # empty history
		read_state_description='',  # empty read state
		task='Click the button on the page',
		include_attributes=['id'],
		step_info=None,
		page_filtered_actions=None,
		max_clickable_elements_length=40000,
		sensitive_data=None,
	)

	# Monkeypatch the DOM serialization on this instance so the prompt always
	# renders one stable, predictable element string.
	serialized_dom.llm_representation = lambda include_attributes=None: '[1]<button id="test-button">Click Me</button>'

	return prompt.get_user_message(use_vision=False)
|
|
|
|
|
|
# Pytest parameterized version of the single-step prompt/response check.
# NOTE(review): async test without an explicit asyncio marker — assumes the
# project's pytest configuration runs async tests automatically; confirm.
@pytest.mark.parametrize(
	'llm_class,model_name',
	[
		(ChatGroq, 'meta-llama/llama-4-maverick-17b-128e-instruct'),
		(ChatGoogle, 'gemini-2.0-flash-exp'),
		(ChatOpenAI, 'gpt-4.1-mini'),
		(ChatAnthropic, 'claude-3-5-sonnet-latest'),
		(ChatAzureOpenAI, 'gpt-4.1-mini'),
		pytest.param(
			ChatOCIRaw,
			# ChatOCIRaw needs keyword-only config, so the "model_name" slot
			# carries a kwargs dict instead of a plain model string (see the
			# isinstance(model_name, dict) branch below).
			{
				'model_id': os.getenv('OCI_MODEL_ID', 'placeholder'),
				'service_endpoint': os.getenv(
					'OCI_SERVICE_ENDPOINT', 'https://inference.generativeai.us-chicago-1.oci.oraclecloud.com'
				),
				'compartment_id': os.getenv('OCI_COMPARTMENT_ID', 'placeholder'),
				'provider': 'meta',
				'temperature': 0.7,
				'max_tokens': 800,
				'frequency_penalty': 0.0,
				'presence_penalty': 0.0,
				'top_p': 0.9,
				'auth_type': 'API_KEY',
				'auth_profile': 'DEFAULT',
			},
			# Skip unless the OCI SDK, local credentials, and both env vars exist.
			marks=pytest.mark.skipif(
				not _check_oci_credentials() or not os.getenv('OCI_MODEL_ID') or not os.getenv('OCI_COMPARTMENT_ID'),
				reason='OCI credentials or environment variables not available',
			),
		),
	],
)
async def test_single_step_parametrized(llm_class, model_name):
	"""Test single step with different LLM providers using pytest parametrize.

	Builds a mocked one-button browser state, injects it into the agent's
	message history, and asserts each provider returns a usable structured
	``AgentOutput`` with token-usage accounting.
	"""
	if isinstance(model_name, dict):
		# Handle ChatOCIRaw which requires keyword arguments
		llm = llm_class(**model_name)
	else:
		llm = llm_class(model=model_name)

	agent = Agent(task='Click the button on the page', llm=llm)

	# Create temporary directory that will stay alive during the test
	with tempfile.TemporaryDirectory() as temp_dir:
		# Create mock state message
		mock_message = create_mock_state_message(temp_dir)

		# Inject the fabricated browser-state message into the agent's history.
		agent.message_manager._set_message_with_type(mock_message, 'state')

		messages = agent.message_manager.get_messages()

		# Ask the provider for a structured AgentOutput over the mocked state.
		response = await llm.ainvoke(messages, agent.AgentOutput)

		# Additional validation for OCI Raw
		if ChatOCIRaw is not None and isinstance(llm, ChatOCIRaw):
			# Verify OCI Raw generates proper Agent actions
			assert response.completion.action is not None
			assert len(response.completion.action) > 0

		# Basic assertions to ensure response is valid
		assert response.completion is not None
		assert response.usage is not None
		assert response.usage.total_tokens > 0
|
|
|
|
|
|
async def test_single_step():
	"""Smoke-test every provider in a loop against the mocked single-button state.

	Unlike the parametrized variant, per-provider failures are logged but do not
	fail the run — this is a best-effort comparison across providers, also
	runnable directly via the ``__main__`` guard at the bottom of the file.
	"""
	# Create a list of models to test
	models: list[BaseChatModel] = [
		ChatGroq(model='meta-llama/llama-4-maverick-17b-128e-instruct'),
		ChatGoogle(model='gemini-2.0-flash-exp'),
		ChatOpenAI(model='gpt-4.1'),
		ChatAnthropic(model='claude-3-5-sonnet-latest'),
		ChatAzureOpenAI(model='gpt-4o-mini'),
	]

	for llm in models:
		print(f'\n{"=" * 60}')
		print(f'Testing with model: {llm.provider} - {llm.model}')
		print(f'{"=" * 60}\n')

		agent = Agent(task='Click the button on the page', llm=llm)

		# Create temporary directory that will stay alive during the test
		with tempfile.TemporaryDirectory() as temp_dir:
			# Create mock state message
			mock_message = create_mock_state_message(temp_dir)

			# Print the mock message content to see what it looks like
			print('Mock state message:')
			print(mock_message.content)
			print('\n' + '=' * 50 + '\n')

			agent.message_manager._set_message_with_type(mock_message, 'state')

			messages = agent.message_manager.get_messages()

			# Ask for a structured AgentOutput; log (not raise) on provider errors.
			try:
				response = await llm.ainvoke(messages, agent.AgentOutput)
				logger.info(f'Response from {llm.provider}: {response.completion}')
				logger.info(f'Actions: {str(response.completion.action)}')

			except Exception as e:
				# logger.exception preserves the traceback, which logger.error dropped
				logger.exception(f'Error with {llm.provider}: {type(e).__name__}: {str(e)}')

		print(f'\n{"=" * 60}\n')
|
|
|
|
|
|
if __name__ == '__main__':
	# Allow running this module directly as a script: execute the loop-based
	# smoke test across all configured providers.
	import asyncio

	asyncio.run(test_single_step())
|