mirror of
https://github.com/browser-use/browser-use
synced 2026-05-06 17:52:15 +02:00
300 lines
10 KiB
Python
300 lines
10 KiB
Python
import os
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from browser_use.llm import ChatAnthropic, ChatGoogle, ChatGroq, ChatOpenAI, ChatOpenRouter
|
|
from browser_use.llm.messages import ContentPartTextParam
|
|
|
|
# The OCI example models are optional: when they cannot be imported, fall back
# to a None placeholder and record that they are unavailable.
try:
    from examples.models.oci_models import xai_llm
except ImportError:
    xai_llm = None
    OCI_MODELS_AVAILABLE = False
else:
    OCI_MODELS_AVAILABLE = True
|
|
|
|
|
|
class CapitalResponse(BaseModel):
    """Structured response for capital question

    Passed as ``output_format`` by the structured-output tests in this file,
    which then assert on both fields case-insensitively.
    """

    # Country the answer refers to (the tests compare it lowercased to 'france').
    country: str
    # Capital of that country (the tests compare it lowercased to 'paris').
    capital: str
|
|
|
|
|
|
class TestChatModels:
    """Test suite for all chat model implementations.

    Each provider gets two tests: a plain-text one that replays a short
    conversation and expects the German capital in the reply, and a
    structured one that requests a ``CapitalResponse`` for France. A test is
    skipped when the credentials it needs are not configured.
    """

    # Imported in the class body, so these names are resolvable while the
    # class-level constants below are being built.
    from browser_use.llm.messages import (
        AssistantMessage,
        BaseMessage,
        SystemMessage,
        UserMessage,
    )

    # Test Constants
    SYSTEM_MESSAGE = SystemMessage(content=[ContentPartTextParam(text='You are a helpful assistant.', type='text')])
    FRANCE_QUESTION = UserMessage(content='What is the capital of France? Answer in one word.')
    FRANCE_ANSWER = AssistantMessage(content='Paris')
    GERMANY_QUESTION = UserMessage(content='What is the capital of Germany? Answer in one word.')

    # Expected values (lowercase: responses are lowercased before comparison)
    EXPECTED_GERMANY_CAPITAL = 'berlin'
    EXPECTED_FRANCE_COUNTRY = 'france'
    EXPECTED_FRANCE_CAPITAL = 'paris'

    # Test messages for conversation
    CONVERSATION_MESSAGES: list[BaseMessage] = [
        SYSTEM_MESSAGE,
        FRANCE_QUESTION,
        FRANCE_ANSWER,
        GERMANY_QUESTION,
    ]

    # Test messages for structured output
    STRUCTURED_MESSAGES: list[BaseMessage] = [UserMessage(content='What is the capital of France?')]

    # Shared helpers
    @staticmethod
    def _require_env(name: str) -> str:
        """Return the value of environment variable *name*, or skip the test.

        The test is skipped with the message ``'<name> not set'`` when the
        variable is missing or empty.
        """
        value = os.getenv(name)
        if not value:
            pytest.skip(f'{name} not set')
        return value

    def _assert_germany_answer(self, completion) -> None:
        """Assert a plain-text completion names the expected German capital."""
        assert isinstance(completion, str)
        assert self.EXPECTED_GERMANY_CAPITAL in completion.lower()

    def _assert_france_answer(self, completion) -> None:
        """Assert a structured completion holds the expected France answer."""
        assert isinstance(completion, CapitalResponse)
        assert completion.country.lower() == self.EXPECTED_FRANCE_COUNTRY
        assert completion.capital.lower() == self.EXPECTED_FRANCE_CAPITAL

    # OpenAI Tests
    @pytest.mark.asyncio
    async def test_openai_ainvoke_normal(self) -> None:
        """Test normal text response from OpenAI"""
        # Only used as a skip guard; the key is not passed to the client here.
        self._require_env('OPENAI_API_KEY')

        chat = ChatOpenAI(model='gpt-4o-mini', temperature=0)
        response = await chat.ainvoke(self.CONVERSATION_MESSAGES)

        self._assert_germany_answer(response.completion)

    @pytest.mark.asyncio
    async def test_openai_ainvoke_structured(self) -> None:
        """Test structured output from OpenAI"""
        self._require_env('OPENAI_API_KEY')

        chat = ChatOpenAI(model='gpt-4o-mini', temperature=0)
        response = await chat.ainvoke(self.STRUCTURED_MESSAGES, output_format=CapitalResponse)

        self._assert_france_answer(response.completion)

    # Anthropic Tests
    @pytest.mark.asyncio
    async def test_anthropic_ainvoke_normal(self) -> None:
        """Test normal text response from Anthropic"""
        self._require_env('ANTHROPIC_API_KEY')

        chat = ChatAnthropic(model='claude-3-5-haiku-latest', max_tokens=100, temperature=0)
        response = await chat.ainvoke(self.CONVERSATION_MESSAGES)

        self._assert_germany_answer(response.completion)

    @pytest.mark.asyncio
    async def test_anthropic_ainvoke_structured(self) -> None:
        """Test structured output from Anthropic"""
        self._require_env('ANTHROPIC_API_KEY')

        chat = ChatAnthropic(model='claude-3-5-haiku-latest', max_tokens=100, temperature=0)
        response = await chat.ainvoke(self.STRUCTURED_MESSAGES, output_format=CapitalResponse)

        self._assert_france_answer(response.completion)

    # Google Gemini Tests
    @pytest.mark.asyncio
    async def test_google_ainvoke_normal(self) -> None:
        """Test normal text response from Google Gemini"""
        api_key = self._require_env('GOOGLE_API_KEY')

        chat = ChatGoogle(model='gemini-2.0-flash', api_key=api_key, temperature=0)
        response = await chat.ainvoke(self.CONVERSATION_MESSAGES)

        self._assert_germany_answer(response.completion)

    @pytest.mark.asyncio
    async def test_google_ainvoke_structured(self) -> None:
        """Test structured output from Google Gemini"""
        api_key = self._require_env('GOOGLE_API_KEY')

        chat = ChatGoogle(model='gemini-2.0-flash', api_key=api_key, temperature=0)
        response = await chat.ainvoke(self.STRUCTURED_MESSAGES, output_format=CapitalResponse)

        self._assert_france_answer(response.completion)

    # Google Gemini with Vertex AI Tests
    @pytest.mark.asyncio
    async def test_google_vertex_ainvoke_normal(self) -> None:
        """Test normal text response from Google Gemini via Vertex AI"""
        project = self._require_env('GOOGLE_CLOUD_PROJECT')

        chat = ChatGoogle(
            model='gemini-2.0-flash',
            vertexai=True,
            project=project,
            location='us-central1',
            temperature=0,
        )
        response = await chat.ainvoke(self.CONVERSATION_MESSAGES)

        self._assert_germany_answer(response.completion)

    @pytest.mark.asyncio
    async def test_google_vertex_ainvoke_structured(self) -> None:
        """Test structured output from Google Gemini via Vertex AI"""
        project = self._require_env('GOOGLE_CLOUD_PROJECT')

        chat = ChatGoogle(
            model='gemini-2.0-flash',
            vertexai=True,
            project=project,
            location='us-central1',
            temperature=0,
        )
        response = await chat.ainvoke(self.STRUCTURED_MESSAGES, output_format=CapitalResponse)

        self._assert_france_answer(response.completion)

    # Groq Tests
    @pytest.mark.asyncio
    async def test_groq_ainvoke_normal(self) -> None:
        """Test normal text response from Groq"""
        self._require_env('GROQ_API_KEY')

        chat = ChatGroq(model='meta-llama/llama-4-maverick-17b-128e-instruct', temperature=0)
        response = await chat.ainvoke(self.CONVERSATION_MESSAGES)

        self._assert_germany_answer(response.completion)

    @pytest.mark.asyncio
    async def test_groq_ainvoke_structured(self) -> None:
        """Test structured output from Groq"""
        self._require_env('GROQ_API_KEY')

        chat = ChatGroq(model='meta-llama/llama-4-maverick-17b-128e-instruct', temperature=0)
        response = await chat.ainvoke(self.STRUCTURED_MESSAGES, output_format=CapitalResponse)

        self._assert_france_answer(response.completion)

    # OpenRouter Tests
    @pytest.fixture
    def openrouter_chat(self) -> ChatOpenRouter:
        """Provides an initialized ChatOpenRouter client for tests."""
        api_key = self._require_env('OPENROUTER_API_KEY')
        return ChatOpenRouter(model='openai/gpt-4o-mini', api_key=api_key, temperature=0)

    @pytest.mark.asyncio
    async def test_openrouter_ainvoke_normal(self, openrouter_chat) -> None:
        """Test normal text response from OpenRouter"""
        response = await openrouter_chat.ainvoke(self.CONVERSATION_MESSAGES)

        self._assert_germany_answer(response.completion)

    @pytest.mark.asyncio
    async def test_openrouter_ainvoke_structured(self, openrouter_chat) -> None:
        """Test structured output from OpenRouter"""
        response = await openrouter_chat.ainvoke(self.STRUCTURED_MESSAGES, output_format=CapitalResponse)

        self._assert_france_answer(response.completion)

    # OCI Raw Tests
    @pytest.fixture
    def oci_raw_chat(self):
        """Provides an initialized ChatOCIRaw client for tests.

        Skips the requesting test when the optional OCI models could not be
        imported, when the default OCI config cannot be loaded, or when the
        model's ``compartment_id`` contains the word "example".
        """
        # Skip if OCI models not available
        if not OCI_MODELS_AVAILABLE:
            pytest.skip('OCI models not available - install with pip install "browser-use[oci]"')

        # Skip if OCI credentials not available - check for config file existence.
        # Deliberately broad: a missing SDK, a missing file and a malformed
        # profile all mean the same thing here, "cannot run OCI tests".
        try:
            import oci

            oci.config.from_file('~/.oci/config', 'DEFAULT')
        except Exception:
            pytest.skip('OCI credentials not available')

        # Skip if using placeholder config. The isinstance check guards against
        # a compartment_id attribute that exists but is None or not a string.
        compartment_id = getattr(xai_llm, 'compartment_id', None) if xai_llm else None
        if isinstance(compartment_id, str) and 'example' in compartment_id.lower():
            pytest.skip('OCI model using placeholder configuration - set real credentials')

        return xai_llm  # xai or cohere

    @pytest.mark.asyncio
    async def test_oci_raw_ainvoke_normal(self, oci_raw_chat) -> None:
        """Test normal text response from OCI Raw"""
        response = await oci_raw_chat.ainvoke(self.CONVERSATION_MESSAGES)

        self._assert_germany_answer(response.completion)

    @pytest.mark.asyncio
    async def test_oci_raw_ainvoke_structured(self, oci_raw_chat) -> None:
        """Test structured output from OCI Raw"""
        response = await oci_raw_chat.ainvoke(self.STRUCTURED_MESSAGES, output_format=CapitalResponse)

        self._assert_france_answer(response.completion)