Files
browser-use/browser_use/llm/google/chat.py
2026-03-23 14:13:08 +08:00

606 lines
22 KiB
Python

import asyncio
import json
import logging
import random
import time
from dataclasses import dataclass, field
from typing import Any, Literal, TypeVar, overload
from google import genai
from google.auth.credentials import Credentials
from google.genai import types
from google.genai.types import MediaModality
from pydantic import BaseModel
from browser_use.llm.base import BaseChatModel
from browser_use.llm.exceptions import ModelProviderError
from browser_use.llm.google.serializer import GoogleMessageSerializer
from browser_use.llm.messages import BaseMessage
from browser_use.llm.schema import SchemaOptimizer
from browser_use.llm.views import ChatInvokeCompletion, ChatInvokeUsage
T = TypeVar('T', bound=BaseModel)
VerifiedGeminiModels = Literal[
'gemini-2.0-flash',
'gemini-2.0-flash-exp',
'gemini-2.0-flash-lite-preview-02-05',
'Gemini-2.0-exp',
'gemini-2.5-flash',
'gemini-2.5-flash-lite',
'gemini-flash-latest',
'gemini-flash-lite-latest',
'gemini-2.5-pro',
'gemini-3-pro-preview',
'gemini-3-flash-preview',
'gemma-3-27b-it',
'gemma-3-4b',
'gemma-3-12b',
'gemma-3n-e2b',
'gemma-3n-e4b',
]
@dataclass
class ChatGoogle(BaseChatModel):
"""
A wrapper around Google's Gemini chat model using the genai client.
This class accepts all genai.Client parameters while adding model,
temperature, and config parameters for the LLM interface.
Args:
model: The Gemini model to use
temperature: Temperature for response generation
config: Additional configuration parameters to pass to generate_content
(e.g., tools, safety_settings, etc.).
api_key: Google API key
vertexai: Whether to use Vertex AI
credentials: Google credentials object
project: Google Cloud project ID
location: Google Cloud location
http_options: HTTP options for the client
include_system_in_user: If True, system messages are included in the first user message
supports_structured_output: If True, uses native JSON mode; if False, uses prompt-based fallback
max_retries: Number of retries for retryable errors (default: 5)
retryable_status_codes: List of HTTP status codes to retry on (default: [429, 500, 502, 503, 504])
retry_base_delay: Base delay in seconds for exponential backoff (default: 1.0)
retry_max_delay: Maximum delay in seconds between retries (default: 60.0)
Example:
from google.genai import types
llm = ChatGoogle(
model='gemini-2.0-flash-exp',
config={
'tools': [types.Tool(code_execution=types.ToolCodeExecution())]
},
max_retries=5,
retryable_status_codes=[429, 500, 502, 503, 504],
retry_base_delay=1.0,
retry_max_delay=60.0,
)
"""
# Model configuration
model: VerifiedGeminiModels | str
temperature: float | None = None
top_p: float | None = None
seed: int | None = None
thinking_budget: int | None = None # for Gemini 2.5: -1 for dynamic (default), 0 disables, or token count
thinking_level: Literal['minimal', 'low', 'medium', 'high'] | None = (
None # for Gemini 3: Pro supports low/high, Flash supports all levels
)
max_output_tokens: int | None = 8096
config: types.GenerateContentConfigDict | None = None
include_system_in_user: bool = False
supports_structured_output: bool = True # New flag
max_retries: int = 5 # Number of retries for retryable errors
retryable_status_codes: list[int] = field(default_factory=lambda: [429, 500, 502, 503, 504]) # Status codes to retry on
retry_base_delay: float = 1.0 # Base delay in seconds for exponential backoff
retry_max_delay: float = 60.0 # Maximum delay in seconds between retries
# Client initialization parameters
api_key: str | None = None
vertexai: bool | None = None
credentials: Credentials | None = None
project: str | None = None
location: str | None = None
http_options: types.HttpOptions | types.HttpOptionsDict | None = None
# Internal client cache to prevent connection issues
_client: genai.Client | None = None
# Static
@property
def provider(self) -> str:
return 'google'
@property
def logger(self) -> logging.Logger:
"""Get logger for this chat instance"""
return logging.getLogger(f'browser_use.llm.google.{self.model}')
def _get_client_params(self) -> dict[str, Any]:
"""Prepare client parameters dictionary."""
# Define base client params
base_params = {
'api_key': self.api_key,
'vertexai': self.vertexai,
'credentials': self.credentials,
'project': self.project,
'location': self.location,
'http_options': self.http_options,
}
# Create client_params dict with non-None values
client_params = {k: v for k, v in base_params.items() if v is not None}
return client_params
def get_client(self) -> genai.Client:
"""
Returns a genai.Client instance.
Returns:
genai.Client: An instance of the Google genai client.
"""
if self._client is not None:
return self._client
client_params = self._get_client_params()
self._client = genai.Client(**client_params)
return self._client
@property
def name(self) -> str:
return str(self.model)
def _get_stop_reason(self, response: types.GenerateContentResponse) -> str | None:
"""Extract stop_reason from Google response."""
if hasattr(response, 'candidates') and response.candidates:
return str(response.candidates[0].finish_reason) if hasattr(response.candidates[0], 'finish_reason') else None
return None
def _get_usage(self, response: types.GenerateContentResponse) -> ChatInvokeUsage | None:
usage: ChatInvokeUsage | None = None
if response.usage_metadata is not None:
image_tokens = 0
if response.usage_metadata.prompt_tokens_details is not None:
image_tokens = sum(
detail.token_count or 0
for detail in response.usage_metadata.prompt_tokens_details
if detail.modality == MediaModality.IMAGE
)
usage = ChatInvokeUsage(
prompt_tokens=response.usage_metadata.prompt_token_count or 0,
completion_tokens=(response.usage_metadata.candidates_token_count or 0)
+ (response.usage_metadata.thoughts_token_count or 0),
total_tokens=response.usage_metadata.total_token_count or 0,
prompt_cached_tokens=response.usage_metadata.cached_content_token_count,
prompt_cache_creation_tokens=None,
prompt_image_tokens=image_tokens,
)
return usage
@overload
async def ainvoke(
self, messages: list[BaseMessage], output_format: None = None, **kwargs: Any
) -> ChatInvokeCompletion[str]: ...
@overload
async def ainvoke(self, messages: list[BaseMessage], output_format: type[T], **kwargs: Any) -> ChatInvokeCompletion[T]: ...
async def ainvoke(
self, messages: list[BaseMessage], output_format: type[T] | None = None, **kwargs: Any
) -> ChatInvokeCompletion[T] | ChatInvokeCompletion[str]:
"""
Invoke the model with the given messages.
Args:
messages: List of chat messages
output_format: Optional Pydantic model class for structured output
Returns:
Either a string response or an instance of output_format
"""
# Serialize messages to Google format with the include_system_in_user flag
contents, system_instruction = GoogleMessageSerializer.serialize_messages(
messages, include_system_in_user=self.include_system_in_user
)
# Build config dictionary starting with user-provided config
config: types.GenerateContentConfigDict = {}
if self.config:
config = self.config.copy()
# Apply model-specific configuration (these can override config)
if self.temperature is not None:
config['temperature'] = self.temperature
else:
config['temperature'] = 1.0 if 'gemini-3' in self.model else 0.5
# Add system instruction if present
if system_instruction:
config['system_instruction'] = system_instruction
if self.top_p is not None:
config['top_p'] = self.top_p
if self.seed is not None:
config['seed'] = self.seed
# Configure thinking based on model version
# Gemini 3 Pro: uses thinking_level only
# Gemini 3 Flash: supports both, defaults to thinking_budget=-1
# Gemini 2.5: uses thinking_budget only
is_gemini_3_pro = 'gemini-3-pro' in self.model
is_gemini_3_flash = 'gemini-3-flash' in self.model
if is_gemini_3_pro:
# Validate: thinking_budget should not be set for Gemini 3 Pro
if self.thinking_budget is not None:
self.logger.warning(
f'thinking_budget={self.thinking_budget} is deprecated for Gemini 3 Pro and may cause '
f'suboptimal performance. Use thinking_level instead.'
)
# Validate: minimal/medium only supported on Flash, not Pro
if self.thinking_level in ('minimal', 'medium'):
self.logger.warning(
f'thinking_level="{self.thinking_level}" is not supported for Gemini 3 Pro. '
f'Only "low" and "high" are valid. Falling back to "low".'
)
self.thinking_level = 'low'
# Default to 'low' for Gemini 3 Pro
if self.thinking_level is None:
self.thinking_level = 'low'
# Map to ThinkingLevel enum (SDK accepts string values)
level = types.ThinkingLevel(self.thinking_level.upper())
config['thinking_config'] = types.ThinkingConfigDict(thinking_level=level)
elif is_gemini_3_flash:
# Gemini 3 Flash supports both thinking_level and thinking_budget
# If user set thinking_level, use that; otherwise default to thinking_budget=-1
if self.thinking_level is not None:
level = types.ThinkingLevel(self.thinking_level.upper())
config['thinking_config'] = types.ThinkingConfigDict(thinking_level=level)
else:
if self.thinking_budget is None:
self.thinking_budget = -1
config['thinking_config'] = types.ThinkingConfigDict(thinking_budget=self.thinking_budget)
else:
# Gemini 2.5 and earlier: use thinking_budget only
if self.thinking_level is not None:
self.logger.warning(
f'thinking_level="{self.thinking_level}" is not supported for this model. '
f'Use thinking_budget instead (0 to disable, -1 for dynamic, or token count).'
)
# Default to -1 for dynamic/auto on 2.5 models
if self.thinking_budget is None and ('gemini-2.5' in self.model or 'gemini-flash' in self.model):
self.thinking_budget = -1
if self.thinking_budget is not None:
config['thinking_config'] = types.ThinkingConfigDict(thinking_budget=self.thinking_budget)
if self.max_output_tokens is not None:
config['max_output_tokens'] = self.max_output_tokens
async def _make_api_call():
start_time = time.time()
self.logger.debug(f'🚀 Starting API call to {self.model}')
try:
if output_format is None:
# Return string response
self.logger.debug('📄 Requesting text response')
response = await self.get_client().aio.models.generate_content(
model=self.model,
contents=contents, # type: ignore
config=config,
)
elapsed = time.time() - start_time
self.logger.debug(f'✅ Got text response in {elapsed:.2f}s')
# Handle case where response.text might be None
text = response.text or ''
if not text:
self.logger.warning('⚠️ Empty text response received')
usage = self._get_usage(response)
return ChatInvokeCompletion(
completion=text,
usage=usage,
stop_reason=self._get_stop_reason(response),
)
else:
# Handle structured output
if self.supports_structured_output:
# Use native JSON mode
self.logger.debug(f'🔧 Requesting structured output for {output_format.__name__}')
config['response_mime_type'] = 'application/json'
# Convert Pydantic model to Gemini-compatible schema
optimized_schema = SchemaOptimizer.create_gemini_optimized_schema(output_format)
gemini_schema = self._fix_gemini_schema(optimized_schema)
config['response_schema'] = gemini_schema
response = await self.get_client().aio.models.generate_content(
model=self.model,
contents=contents,
config=config,
)
elapsed = time.time() - start_time
self.logger.debug(f'✅ Got structured response in {elapsed:.2f}s')
usage = self._get_usage(response)
# Handle case where response.parsed might be None
if response.parsed is None:
self.logger.debug('📝 Parsing JSON from text response')
# When using response_schema, Gemini returns JSON as text
if response.text:
try:
# Handle JSON wrapped in markdown code blocks (common Gemini behavior)
text = response.text.strip()
if text.startswith('```json') and text.endswith('```'):
text = text[7:-3].strip()
self.logger.debug('🔧 Stripped ```json``` wrapper from response')
elif text.startswith('```') and text.endswith('```'):
text = text[3:-3].strip()
self.logger.debug('🔧 Stripped ``` wrapper from response')
# Parse the JSON text and validate with the Pydantic model
parsed_data = json.loads(text)
return ChatInvokeCompletion(
completion=output_format.model_validate(parsed_data),
usage=usage,
stop_reason=self._get_stop_reason(response),
)
except (json.JSONDecodeError, ValueError) as e:
self.logger.error(f'❌ Failed to parse JSON response: {str(e)}')
self.logger.debug(f'Raw response text: {response.text[:200]}...')
raise ModelProviderError(
message=f'Failed to parse or validate response {response}: {str(e)}',
status_code=500,
model=self.model,
) from e
else:
self.logger.error('❌ No response text received')
raise ModelProviderError(
message=f'No response from model {response}',
status_code=500,
model=self.model,
)
# Ensure we return the correct type
if isinstance(response.parsed, output_format):
return ChatInvokeCompletion(
completion=response.parsed,
usage=usage,
stop_reason=self._get_stop_reason(response),
)
else:
# If it's not the expected type, try to validate it
return ChatInvokeCompletion(
completion=output_format.model_validate(response.parsed),
usage=usage,
stop_reason=self._get_stop_reason(response),
)
else:
# Fallback: Request JSON in the prompt for models without native JSON mode
self.logger.debug(f'🔄 Using fallback JSON mode for {output_format.__name__}')
# Create a copy of messages to modify
modified_messages = [m.model_copy(deep=True) for m in messages]
# Add JSON instruction to the last message
if modified_messages and isinstance(modified_messages[-1].content, str):
json_instruction = f'\n\nPlease respond with a valid JSON object that matches this schema: {SchemaOptimizer.create_optimized_json_schema(output_format)}'
modified_messages[-1].content += json_instruction
# Re-serialize with modified messages
fallback_contents, fallback_system = GoogleMessageSerializer.serialize_messages(
modified_messages, include_system_in_user=self.include_system_in_user
)
# Update config with fallback system instruction if present
fallback_config = config.copy()
if fallback_system:
fallback_config['system_instruction'] = fallback_system
response = await self.get_client().aio.models.generate_content(
model=self.model,
contents=fallback_contents, # type: ignore
config=fallback_config,
)
elapsed = time.time() - start_time
self.logger.debug(f'✅ Got fallback response in {elapsed:.2f}s')
usage = self._get_usage(response)
# Try to extract JSON from the text response
if response.text:
try:
# Try to find JSON in the response
text = response.text.strip()
# Common patterns: JSON wrapped in markdown code blocks
if text.startswith('```json') and text.endswith('```'):
text = text[7:-3].strip()
elif text.startswith('```') and text.endswith('```'):
text = text[3:-3].strip()
# Parse and validate
parsed_data = json.loads(text)
return ChatInvokeCompletion(
completion=output_format.model_validate(parsed_data),
usage=usage,
stop_reason=self._get_stop_reason(response),
)
except (json.JSONDecodeError, ValueError) as e:
self.logger.error(f'❌ Failed to parse fallback JSON: {str(e)}')
self.logger.debug(f'Raw response text: {response.text[:200]}...')
raise ModelProviderError(
message=f'Model does not support JSON mode and failed to parse JSON from text response: {str(e)}',
status_code=500,
model=self.model,
) from e
else:
self.logger.error('❌ No response text in fallback mode')
raise ModelProviderError(
message='No response from model',
status_code=500,
model=self.model,
)
except Exception as e:
elapsed = time.time() - start_time
self.logger.error(f'💥 API call failed after {elapsed:.2f}s: {type(e).__name__}: {e}')
# Re-raise the exception
raise
# Retry logic for certain errors with exponential backoff
assert self.max_retries >= 1, 'max_retries must be at least 1'
for attempt in range(self.max_retries):
try:
return await _make_api_call()
except ModelProviderError as e:
# Retry if status code is in retryable list and we have attempts left
if e.status_code in self.retryable_status_codes and attempt < self.max_retries - 1:
# Exponential backoff with jitter: base_delay * 2^attempt + random jitter
delay = min(self.retry_base_delay * (2**attempt), self.retry_max_delay)
jitter = random.uniform(0, delay * 0.1) # 10% jitter
total_delay = delay + jitter
self.logger.warning(
f'⚠️ Got {e.status_code} error, retrying in {total_delay:.1f}s... (attempt {attempt + 1}/{self.max_retries})'
)
await asyncio.sleep(total_delay)
continue
# Otherwise raise
raise
except Exception as e:
# For non-ModelProviderError, wrap and raise
error_message = str(e)
status_code: int | None = None
# Try to extract status code if available
if hasattr(e, 'response'):
response_obj = getattr(e, 'response', None)
if response_obj and hasattr(response_obj, 'status_code'):
status_code = getattr(response_obj, 'status_code', None)
# Enhanced timeout error handling
if 'timeout' in error_message.lower() or 'cancelled' in error_message.lower():
if isinstance(e, asyncio.CancelledError) or 'CancelledError' in str(type(e)):
error_message = 'Gemini API request was cancelled (likely timeout). Consider: 1) Reducing input size, 2) Using a different model, 3) Checking network connectivity.'
status_code = 504
else:
status_code = 408
elif any(indicator in error_message.lower() for indicator in ['forbidden', '403']):
status_code = 403
elif any(
indicator in error_message.lower()
for indicator in ['rate limit', 'resource exhausted', 'quota exceeded', 'too many requests', '429']
):
status_code = 429
elif any(
indicator in error_message.lower()
for indicator in ['service unavailable', 'internal server error', 'bad gateway', '503', '502', '500']
):
status_code = 503
raise ModelProviderError(
message=error_message,
status_code=status_code or 502,
model=self.name,
) from e
raise RuntimeError('Retry loop completed without return or exception')
def _fix_gemini_schema(self, schema: dict[str, Any]) -> dict[str, Any]:
"""
Convert a Pydantic model to a Gemini-compatible schema.
This function removes unsupported properties like 'additionalProperties' and resolves
$ref references that Gemini doesn't support.
"""
# Handle $defs and $ref resolution
if '$defs' in schema:
defs = schema.pop('$defs')
def resolve_refs(obj: Any) -> Any:
if isinstance(obj, dict):
if '$ref' in obj:
ref = obj.pop('$ref')
ref_name = ref.split('/')[-1]
if ref_name in defs:
# Replace the reference with the actual definition
resolved = defs[ref_name].copy()
# Merge any additional properties from the reference
for key, value in obj.items():
if key != '$ref':
resolved[key] = value
return resolve_refs(resolved)
return obj
else:
# Recursively process all dictionary values
return {k: resolve_refs(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [resolve_refs(item) for item in obj]
return obj
schema = resolve_refs(schema)
# Remove unsupported properties
def clean_schema(obj: Any, parent_key: str | None = None) -> Any:
if isinstance(obj, dict):
# Remove unsupported properties
cleaned = {}
for key, value in obj.items():
# Only strip 'title' when it's a JSON Schema metadata field (not inside 'properties')
# 'title' as a metadata field appears at schema level, not as a property name
is_metadata_title = key == 'title' and parent_key != 'properties'
if key not in ['additionalProperties', 'default'] and not is_metadata_title:
cleaned_value = clean_schema(value, parent_key=key)
# Handle empty object properties - Gemini doesn't allow empty OBJECT types
if (
key == 'properties'
and isinstance(cleaned_value, dict)
and len(cleaned_value) == 0
and isinstance(obj.get('type', ''), str)
and obj.get('type', '').upper() == 'OBJECT'
):
# Convert empty object to have at least one property
cleaned['properties'] = {'_placeholder': {'type': 'string'}}
else:
cleaned[key] = cleaned_value
# If this is an object type with empty properties, add a placeholder
if (
isinstance(cleaned.get('type', ''), str)
and cleaned.get('type', '').upper() == 'OBJECT'
and 'properties' in cleaned
and isinstance(cleaned['properties'], dict)
and len(cleaned['properties']) == 0
):
cleaned['properties'] = {'_placeholder': {'type': 'string'}}
return cleaned
elif isinstance(obj, list):
return [clean_schema(item, parent_key=parent_key) for item in obj]
return obj
return clean_schema(schema)