import asyncio
import logging
import os
import platform
import signal
import tempfile
import threading
import time
from collections.abc import Callable, Coroutine
from fnmatch import fnmatch
from functools import cache, wraps
from pathlib import Path
from sys import stderr
from typing import Any, Literal, ParamSpec, TypeVar
from urllib.parse import urlparse

import portalocker
from dotenv import load_dotenv

try:
    import psutil

    PSUTIL_AVAILABLE = True
except ImportError:
    PSUTIL_AVAILABLE = False

load_dotenv()


logger = logging.getLogger(__name__)

# Import error types - these may need to be adjusted based on actual import paths
try:
    from openai import BadRequestError as OpenAIBadRequestError
except ImportError:
    OpenAIBadRequestError = None

try:
    from groq import BadRequestError as GroqBadRequestError  # type: ignore[import-not-found]
except ImportError:
    GroqBadRequestError = None


# Global flag to prevent duplicate exit messages
_exiting = False

# Define generic type variables for return type and parameters
R = TypeVar('R')
T = TypeVar('T')
P = ParamSpec('P')


class SignalHandler:
    """
    A modular and reusable signal handling system for managing SIGINT (Ctrl+C), SIGTERM,
    and other signals in asyncio applications.

    This class provides:
    - Configurable signal handling for SIGINT and SIGTERM
    - Support for custom pause/resume callbacks
    - Management of event loop state across signals
    - Standardized handling of first and second Ctrl+C presses
    - Cross-platform compatibility (with simplified behavior on Windows)
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop | None = None,
        pause_callback: Callable[[], None] | None = None,
        resume_callback: Callable[[], None] | None = None,
        custom_exit_callback: Callable[[], None] | None = None,
        exit_on_second_int: bool = True,
        interruptible_task_patterns: list[str] | None = None,
    ):
        """
        Initialize the signal handler.

        Args:
            loop: The asyncio event loop to use. Defaults to current event loop.
            pause_callback: Function to call when system is paused (first Ctrl+C)
            resume_callback: Function to call when system is resumed
            custom_exit_callback: Function to call on exit (second Ctrl+C or SIGTERM)
            exit_on_second_int: Whether to exit on second SIGINT (Ctrl+C)
            interruptible_task_patterns: List of patterns to match task names that should be
                canceled on first Ctrl+C (default: ['step', 'multi_act', 'get_next_action'])
        """
        self.loop = loop or asyncio.get_event_loop()
        self.pause_callback = pause_callback
        self.resume_callback = resume_callback
        self.custom_exit_callback = custom_exit_callback
        self.exit_on_second_int = exit_on_second_int
        self.interruptible_task_patterns = interruptible_task_patterns or ['step', 'multi_act', 'get_next_action']
        self.is_windows = platform.system() == 'Windows'

        # Initialize loop state attributes
        self._initialize_loop_state()

        # Store original signal handlers to restore them later if needed
        self.original_sigint_handler = None
        self.original_sigterm_handler = None

    def _initialize_loop_state(self) -> None:
        """Initialize loop state attributes used for signal handling."""
        setattr(self.loop, 'ctrl_c_pressed', False)
        setattr(self.loop, 'waiting_for_input', False)

    def register(self) -> None:
        """Register signal handlers for SIGINT and SIGTERM."""
        try:
            if self.is_windows:
                # On Windows, use simple signal handling with immediate exit on Ctrl+C
                def windows_handler(sig, frame):
                    print('\n\n🛑 Got Ctrl+C. Exiting immediately on Windows...\n', file=stderr)
                    # Run the custom exit callback if provided
                    if self.custom_exit_callback:
                        self.custom_exit_callback()
                    os._exit(0)

                self.original_sigint_handler = signal.signal(signal.SIGINT, windows_handler)
            else:
                # On Unix-like systems, use asyncio's signal handling for smoother experience
                self.original_sigint_handler = self.loop.add_signal_handler(signal.SIGINT, lambda: self.sigint_handler())
                self.original_sigterm_handler = self.loop.add_signal_handler(signal.SIGTERM, lambda: self.sigterm_handler())

        except Exception:
            # there are situations where signal handlers are not supported, e.g.
            # - when running in a thread other than the main thread
            # - some operating systems
            # - inside jupyter notebooks
            pass

    def unregister(self) -> None:
        """Unregister signal handlers and restore original handlers if possible."""
        try:
            if self.is_windows:
                # On Windows, just restore the original SIGINT handler
                if self.original_sigint_handler:
                    signal.signal(signal.SIGINT, self.original_sigint_handler)
            else:
                # On Unix-like systems, use asyncio's signal handler removal
                self.loop.remove_signal_handler(signal.SIGINT)
                self.loop.remove_signal_handler(signal.SIGTERM)

                # Restore original handlers if available
                if self.original_sigint_handler:
                    signal.signal(signal.SIGINT, self.original_sigint_handler)
                if self.original_sigterm_handler:
                    signal.signal(signal.SIGTERM, self.original_sigterm_handler)
        except Exception as e:
            logger.warning(f'Error while unregistering signal handlers: {e}')

    def _handle_second_ctrl_c(self) -> None:
        """
        Handle a second Ctrl+C press by performing cleanup and exiting.
        This is shared logic used by both sigint_handler and wait_for_resume.
        """
        global _exiting

        if not _exiting:
            _exiting = True

            # Call custom exit callback if provided
            if self.custom_exit_callback:
                try:
                    self.custom_exit_callback()
                except Exception as e:
                    logger.error(f'Error in exit callback: {e}')

        # Force immediate exit - more reliable than sys.exit()
        print('\n\n🛑 Got second Ctrl+C. Exiting immediately...\n', file=stderr)

        # Reset terminal to a clean state by sending multiple escape sequences
        # Order matters for terminal resets - we try different approaches

        # Reset terminal modes for both stdout and stderr
        print('\033[?25h', end='', flush=True, file=stderr)  # Show cursor
        print('\033[?25h', end='', flush=True)  # Show cursor

        # Reset text attributes and terminal modes
        print('\033[0m', end='', flush=True, file=stderr)  # Reset text attributes
        print('\033[0m', end='', flush=True)  # Reset text attributes

        # Disable special input modes that may cause arrow keys to output control chars
        print('\033[?1l', end='', flush=True, file=stderr)  # Reset cursor keys to normal mode
        print('\033[?1l', end='', flush=True)  # Reset cursor keys to normal mode

        # Disable bracketed paste mode
        print('\033[?2004l', end='', flush=True, file=stderr)
        print('\033[?2004l', end='', flush=True)

        # Carriage return helps ensure a clean line
        print('\r', end='', flush=True, file=stderr)
        print('\r', end='', flush=True)

        # these ^^ attempts don't work as far as we can tell
        # we still don't know what causes the broken input; if you know how to fix it, please let us know
        print('(tip: press [Enter] once to fix escape codes appearing after chrome exit)', file=stderr)

        os._exit(0)

    def sigint_handler(self) -> None:
        """
        SIGINT (Ctrl+C) handler.

        First Ctrl+C: Cancel current step and pause.
        Second Ctrl+C: Exit immediately if exit_on_second_int is True.
        """
        global _exiting

        if _exiting:
            # Already exiting, force exit immediately
            os._exit(0)

        if getattr(self.loop, 'ctrl_c_pressed', False):
            # If we're in the waiting for input state, let the pause method handle it
            if getattr(self.loop, 'waiting_for_input', False):
                return

            # Second Ctrl+C - exit immediately if configured to do so
            if self.exit_on_second_int:
                self._handle_second_ctrl_c()

        # Mark that Ctrl+C was pressed
        setattr(self.loop, 'ctrl_c_pressed', True)

        # Cancel current tasks that should be interruptible - this is crucial for immediate pausing
        self._cancel_interruptible_tasks()

        # Call pause callback if provided - this sets the paused flag
        if self.pause_callback:
            try:
                self.pause_callback()
            except Exception as e:
                logger.error(f'Error in pause callback: {e}')

        # Log pause message after pause_callback is called (not before)
        print('----------------------------------------------------------------------', file=stderr)

    def sigterm_handler(self) -> None:
        """
        SIGTERM handler.

        Always exits the program completely.
        """
        global _exiting
        if not _exiting:
            _exiting = True
            print('\n\n🛑 SIGTERM received. Exiting immediately...\n\n', file=stderr)

            # Call custom exit callback if provided
            if self.custom_exit_callback:
                self.custom_exit_callback()

        os._exit(0)

    def _cancel_interruptible_tasks(self) -> None:
        """Cancel current tasks that should be interruptible."""
        current_task = asyncio.current_task(self.loop)
        for task in asyncio.all_tasks(self.loop):
            if task != current_task and not task.done():
                task_name = task.get_name() if hasattr(task, 'get_name') else str(task)
                # Cancel tasks that match certain patterns
                if any(pattern in task_name for pattern in self.interruptible_task_patterns):
                    logger.debug(f'Cancelling task: {task_name}')
                    task.cancel()
                    # Add exception handler to silence "Task exception was never retrieved" warnings
                    task.add_done_callback(lambda t: t.exception() if t.cancelled() else None)

        # Also cancel the current task if it's interruptible
        if current_task and not current_task.done():
            task_name = current_task.get_name() if hasattr(current_task, 'get_name') else str(current_task)
            if any(pattern in task_name for pattern in self.interruptible_task_patterns):
                logger.debug(f'Cancelling current task: {task_name}')
                current_task.cancel()

    def wait_for_resume(self) -> None:
        """
        Wait for user input to resume or exit.

        This method should be called after handling the first Ctrl+C.
        It temporarily restores default signal handling to allow catching
        a second Ctrl+C directly.
        """
        # Set flag to indicate we're waiting for input
        setattr(self.loop, 'waiting_for_input', True)

        # Temporarily restore default signal handling for SIGINT
        # This ensures KeyboardInterrupt will be raised during input()
        original_handler = signal.getsignal(signal.SIGINT)
        try:
            signal.signal(signal.SIGINT, signal.default_int_handler)
        except ValueError:
            # we are running in a thread other than the main thread
            # or signal handlers are not supported for some other reason
            pass

        green = '\x1b[32;1m'
        red = '\x1b[31m'
        blink = '\033[33;5m'
        unblink = '\033[0m'
        reset = '\x1b[0m'

        try:  # escape code is to blink the ...
            print(
                f'➡️ Press {green}[Enter]{reset} to resume or {red}[Ctrl+C]{reset} again to exit{blink}...{unblink} ',
                end='',
                flush=True,
                file=stderr,
            )
            input()  # This will raise KeyboardInterrupt on Ctrl+C

            # Call resume callback if provided
            if self.resume_callback:
                self.resume_callback()
        except KeyboardInterrupt:
            # Use the shared method to handle second Ctrl+C
            self._handle_second_ctrl_c()
        finally:
            try:
                # Restore our signal handler
                signal.signal(signal.SIGINT, original_handler)
                setattr(self.loop, 'waiting_for_input', False)
            except Exception:
                pass

    def reset(self) -> None:
        """Reset state after resuming."""
        # Clear the flags
        if hasattr(self.loop, 'ctrl_c_pressed'):
            setattr(self.loop, 'ctrl_c_pressed', False)
        if hasattr(self.loop, 'waiting_for_input'):
            setattr(self.loop, 'waiting_for_input', False)


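# Illustrative sketch (not part of the upstream API; the function and callbacks
# below are hypothetical placeholders): one way SignalHandler is typically wired
# into an asyncio app. register() installs the handlers; the first Ctrl+C fires
# pause_callback, the second exits.
async def _example_signal_handling() -> None:
    handler = SignalHandler(
        pause_callback=lambda: print('paused', file=stderr),
        resume_callback=lambda: print('resumed', file=stderr),
        exit_on_second_int=True,
    )
    handler.register()
    try:
        await asyncio.sleep(60)  # stand-in for the real agent loop
    finally:
        handler.unregister()

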
def time_execution_sync(additional_text: str = '') -> Callable[[Callable[P, R]], Callable[P, R]]:
    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            start_time = time.time()
            result = func(*args, **kwargs)
            execution_time = time.time() - start_time
            # Only log if execution takes more than 0.25 seconds
            if execution_time > 0.25:
                self_has_logger = args and getattr(args[0], 'logger', None)
                if self_has_logger:
                    logger = getattr(args[0], 'logger')
                elif 'agent' in kwargs:
                    logger = getattr(kwargs['agent'], 'logger')
                elif 'browser_session' in kwargs:
                    logger = getattr(kwargs['browser_session'], 'logger')
                else:
                    logger = logging.getLogger(__name__)
                logger.debug(f'⏳ {additional_text.strip("-")}() took {execution_time:.2f}s')
            return result

        return wrapper

    return decorator


def time_execution_async(
    additional_text: str = '',
) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]:
    def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            start_time = time.time()
            result = await func(*args, **kwargs)
            execution_time = time.time() - start_time
            # Only log if execution takes more than 0.25 seconds to avoid spamming the logs
            # you can lower this threshold locally when you're doing dev work to performance optimize stuff
            if execution_time > 0.25:
                self_has_logger = args and getattr(args[0], 'logger', None)
                if self_has_logger:
                    logger = getattr(args[0], 'logger')
                elif 'agent' in kwargs:
                    logger = getattr(kwargs['agent'], 'logger')
                elif 'browser_session' in kwargs:
                    logger = getattr(kwargs['browser_session'], 'logger')
                else:
                    logger = logging.getLogger(__name__)
                logger.debug(f'⏳ {additional_text.strip("-")}() took {execution_time:.2f}s')
            return result

        return wrapper

    return decorator


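# Illustrative sketch (hypothetical function, not used elsewhere in this module):
# applying @time_execution_async. Anything slower than 0.25s is logged at DEBUG,
# with leading '-' characters stripped from the label; @time_execution_sync works
# the same way for regular functions.
@time_execution_async('--example_slow_lookup')
async def _example_slow_lookup(url: str) -> str:
    await asyncio.sleep(0.3)  # slow enough to trigger the debug log line
    return url

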
def singleton(cls):
    instance = [None]

    def wrapper(*args, **kwargs):
        if instance[0] is None:
            instance[0] = cls(*args, **kwargs)
        return instance[0]

    return wrapper


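# Illustrative sketch (hypothetical class): @singleton replaces the class with a
# factory that constructs it once and returns the cached instance afterwards, so
# arguments passed on later calls are silently ignored.
@singleton
class _ExampleSettings:
    def __init__(self, name: str = 'default'):
        self.name = name


# _ExampleSettings('a') is _ExampleSettings('b')  -> True (one instance, name == 'a')

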
def check_env_variables(keys: list[str], any_or_all=all) -> bool:
    """Check that required environment variables are set and non-empty (pass any_or_all=any to require just one)"""
    return any_or_all(os.getenv(key, '').strip() for key in keys)


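# Illustrative sketch (the key names are hypothetical examples): any_or_all=any
# accepts at least one key being set; the default all() requires every key.
def _example_env_check() -> bool:
    return check_env_variables(['OPENAI_API_KEY', 'ANTHROPIC_API_KEY'], any_or_all=any)

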
def is_unsafe_pattern(pattern: str) -> bool:
    """
    Check if a domain pattern has complex wildcards that could match too many domains.

    Args:
        pattern: The domain pattern to check

    Returns:
        bool: True if the pattern has unsafe wildcards, False otherwise
    """
    # Extract domain part if there's a scheme
    if '://' in pattern:
        _, pattern = pattern.split('://', 1)

    # Remove safe patterns (*.domain and domain.*)
    bare_domain = pattern.replace('.*', '').replace('*.', '')

    # If there are still wildcards, it's potentially unsafe
    return '*' in bare_domain


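# Illustrative sketch (hypothetical values): a wildcard that survives the
# stripping of the safe '*.' / '.*' forms is flagged as unsafe.
def _example_unsafe_pattern_check() -> None:
    assert is_unsafe_pattern('*google.com')  # wildcard left after stripping -> unsafe
    assert not is_unsafe_pattern('*.google.com')  # plain subdomain wildcard -> safe

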
def match_url_with_domain_pattern(url: str, domain_pattern: str, log_warnings: bool = False) -> bool:
    """
    Check if a URL matches a domain pattern. SECURITY CRITICAL.

    Supports optional glob patterns and schemes:
    - *.example.com will match sub.example.com and example.com
    - *google.com will match google.com, agoogle.com, and www.google.com
    - http*://example.com will match http://example.com, https://example.com
    - chrome-extension://* will match chrome-extension://aaaaaaaaaaaa and chrome-extension://bbbbbbbbbbbbb

    When no scheme is specified, https is used by default for security.
    For example, 'example.com' will match 'https://example.com' but not 'http://example.com'.

    Note: about:blank must be handled at the callsite, not inside this function.

    Args:
        url: The URL to check
        domain_pattern: Domain pattern to match against
        log_warnings: Whether to log warnings about unsafe patterns

    Returns:
        bool: True if the URL matches the pattern, False otherwise
    """
    try:
        # Note: about:blank should be handled at the callsite, not here
        if url == 'about:blank':
            return False

        parsed_url = urlparse(url)

        # Extract only the hostname and scheme components
        scheme = parsed_url.scheme.lower() if parsed_url.scheme else ''
        domain = parsed_url.hostname.lower() if parsed_url.hostname else ''

        if not scheme or not domain:
            return False

        # Normalize the domain pattern
        domain_pattern = domain_pattern.lower()

        # Handle pattern with scheme
        if '://' in domain_pattern:
            pattern_scheme, pattern_domain = domain_pattern.split('://', 1)
        else:
            pattern_scheme = 'https'  # Default to matching only https for security
            pattern_domain = domain_pattern

        # Handle port in pattern (we strip ports from patterns since we already
        # extracted only the hostname from the URL)
        if ':' in pattern_domain and not pattern_domain.startswith(':'):
            pattern_domain = pattern_domain.split(':', 1)[0]

        # If scheme doesn't match, return False
        if not fnmatch(scheme, pattern_scheme):
            return False

        # Check for exact match
        if pattern_domain == '*' or domain == pattern_domain:
            return True

        # Handle glob patterns
        if '*' in pattern_domain:
            # Check for unsafe glob patterns
            # First, check for patterns like *.*.domain which are unsafe
            if pattern_domain.count('*.') > 1 or pattern_domain.count('.*') > 1:
                if log_warnings:
                    logger = logging.getLogger(__name__)
                    logger.error(f'⛔️ Multiple wildcards in pattern=[{domain_pattern}] are not supported')
                return False  # Don't match unsafe patterns

            # Check for wildcards in TLD part (example.*)
            if pattern_domain.endswith('.*'):
                if log_warnings:
                    logger = logging.getLogger(__name__)
                    logger.error(f'⛔️ Wildcard TLDs like in pattern=[{domain_pattern}] are not supported for security')
                return False  # Don't match unsafe patterns

            # Then check for embedded wildcards
            bare_domain = pattern_domain.replace('*.', '')
            if '*' in bare_domain:
                if log_warnings:
                    logger = logging.getLogger(__name__)
                    logger.error(f'⛔️ Only *.domain style patterns are supported, ignoring pattern=[{domain_pattern}]')
                return False  # Don't match unsafe patterns

            # Special handling so that *.google.com also matches bare google.com
            if pattern_domain.startswith('*.'):
                parent_domain = pattern_domain[2:]
                if domain == parent_domain or fnmatch(domain, parent_domain):
                    return True

            # Normal case: match domain against pattern
            if fnmatch(domain, pattern_domain):
                return True

        return False
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error(f'⛔️ Error matching URL {url} with pattern {domain_pattern}: {type(e).__name__}: {e}')
        return False


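# Illustrative sketch (hypothetical URLs) of the matching rules above; note the
# https-only default when the pattern carries no scheme.
def _example_domain_matching() -> None:
    assert match_url_with_domain_pattern('https://sub.example.com', '*.example.com')
    assert match_url_with_domain_pattern('https://example.com', '*.example.com')  # bare domain matches too
    assert not match_url_with_domain_pattern('http://example.com', 'example.com')  # https-only by default
    assert match_url_with_domain_pattern('http://example.com', 'http*://example.com')

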
def merge_dicts(a: dict, b: dict, path: tuple[str, ...] = ()):
    for key in b:
        if key in a:
            if isinstance(a[key], dict) and isinstance(b[key], dict):
                merge_dicts(a[key], b[key], path + (str(key),))
            elif isinstance(a[key], list) and isinstance(b[key], list):
                a[key] = a[key] + b[key]
            elif a[key] != b[key]:
                raise Exception('Conflict at ' + '.'.join(path + (str(key),)))
        else:
            a[key] = b[key]
    return a


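# Illustrative sketch (hypothetical dicts): merge_dicts mutates and returns `a`;
# nested dicts merge recursively, lists concatenate, equal leaves pass through,
# and conflicting leaves raise with a dotted path such as 'browser.headless'.
def _example_merge_dicts() -> dict:
    base = {'browser': {'headless': True, 'args': ['--foo']}}
    override = {'browser': {'args': ['--bar']}, 'retries': 3}
    return merge_dicts(base, override)
    # -> {'browser': {'headless': True, 'args': ['--foo', '--bar']}, 'retries': 3}

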
@cache
def get_browser_use_version() -> str:
    """Get the browser-use package version using the same logic as Agent._set_browser_use_version_and_source"""
    try:
        package_root = Path(__file__).parent.parent
        pyproject_path = package_root / 'pyproject.toml'

        # Try to read version from pyproject.toml
        if pyproject_path.exists():
            import re

            with open(pyproject_path, encoding='utf-8') as f:
                content = f.read()
                match = re.search(r'version\s*=\s*["\']([^"\']+)["\']', content)
                if match:
                    version = f'{match.group(1)}'
                    os.environ['LIBRARY_VERSION'] = version  # used by bubus event_schema so all Event schemas include versioning
                    return version

        # If pyproject.toml doesn't exist, fall back to the installed package metadata
        from importlib.metadata import version as get_version

        version = str(get_version('browser-use'))
        os.environ['LIBRARY_VERSION'] = version
        return version

    except Exception as e:
        logger.debug(f'Error detecting browser-use version: {type(e).__name__}: {e}')
        return 'unknown'


def _log_pretty_path(path: str | Path | None) -> str:
    """Pretty-print a path, shorten home dir to ~ and cwd to ."""

    if not path or not str(path).strip():
        return ''  # always falsy in -> falsy out so it can be used in ternaries

    # don't print anything that's not a path
    if not isinstance(path, (str, Path)):
        # no other types are safe to just str(path) and log to terminal unless we know what they are
        # e.g. what if we get storage_state=dict | Path and the dict version could contain real cookies
        return f'<{type(path).__name__}>'

    # replace home dir and cwd with ~ and .
    pretty_path = str(path).replace(str(Path.home()), '~').replace(str(Path.cwd().resolve()), '.')

    # wrap in quotes if it contains spaces
    if pretty_path.strip() and ' ' in pretty_path:
        pretty_path = f'"{pretty_path}"'

    return pretty_path


def _log_pretty_url(s: str, max_len: int | None = 22) -> str:
    """Truncate/pretty-print a URL with a maximum length, removing the protocol and www. prefix"""
    s = s.replace('https://', '').replace('http://', '').replace('www.', '')
    if max_len is not None and len(s) > max_len:
        return s[:max_len] + '…'
    return s


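# Illustrative sketch (hypothetical inputs; the path output depends on $HOME and cwd):
def _example_pretty_logging() -> None:
    print(_log_pretty_path(Path.home() / 'Downloads' / 'report.pdf'))  # ~/Downloads/report.pdf
    print(_log_pretty_url('https://www.example.com/some/long/path'))  # example.com/some/long/…

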
# Global semaphore registry for retry decorator
GLOBAL_RETRY_SEMAPHORES: dict[str, asyncio.Semaphore] = {}
GLOBAL_RETRY_SEMAPHORE_LOCK = threading.Lock()

# Multiprocess semaphore support
MULTIPROCESS_SEMAPHORE_DIR = Path(tempfile.gettempdir()) / 'browser_use_semaphores'
MULTIPROCESS_SEMAPHORE_DIR.mkdir(exist_ok=True)

# Global multiprocess semaphore registry
MULTIPROCESS_SEMAPHORES: dict[str, portalocker.utils.NamedBoundedSemaphore] = {}
MULTIPROCESS_SEMAPHORE_LOCK = threading.Lock()

# Global overload detection state
_last_overload_check = 0.0
_overload_check_interval = 5.0  # Check every 5 seconds
_active_retry_operations = 0
_active_operations_lock = threading.Lock()


def _check_system_overload() -> tuple[bool, str]:
    """Check if system is overloaded and return (is_overloaded, reason)"""
    if not PSUTIL_AVAILABLE:
        return False, ''

    try:
        # Get system stats
        cpu_percent = psutil.cpu_percent(interval=0.1)
        memory = psutil.virtual_memory()

        # Check thresholds
        reasons = []
        is_overloaded = False

        if cpu_percent > 85:
            is_overloaded = True
            reasons.append(f'CPU: {cpu_percent:.1f}%')

        if memory.percent > 85:
            is_overloaded = True
            reasons.append(f'Memory: {memory.percent:.1f}%')

        # Check number of concurrent operations
        with _active_operations_lock:
            if _active_retry_operations > 30:
                is_overloaded = True
                reasons.append(f'Active operations: {_active_retry_operations}')

        return is_overloaded, ', '.join(reasons)
    except Exception:
        return False, ''


def retry(
    wait: float = 3,
    retries: int = 3,
    timeout: float = 5,
    retry_on: tuple[type[Exception], ...] | None = None,
    backoff_factor: float = 1.0,
    semaphore_limit: int | None = None,
    semaphore_name: str | None = None,
    semaphore_lax: bool = True,
    semaphore_scope: Literal['global', 'class', 'self', 'multiprocess'] = 'global',
    semaphore_timeout: float | None = None,
):
    """
    Retry decorator with semaphore support for async functions.

    Args:
        wait: Seconds to wait between retries
        retries: Number of retry attempts after the initial failure
        timeout: Per-attempt timeout in seconds
        retry_on: Tuple of exception types to retry on (None = retry all exceptions)
        backoff_factor: Multiplier for wait time after each retry (1.0 = no backoff)
        semaphore_limit: Max concurrent executions (creates semaphore if needed)
        semaphore_name: Name for semaphore (defaults to function name)
        semaphore_lax: If True, continue without semaphore on acquisition failure
        semaphore_scope: Scope for semaphore sharing:
            - 'global': All calls share one semaphore (default)
            - 'class': All instances of a class share one semaphore
            - 'self': Each instance gets its own semaphore
            - 'multiprocess': All processes on the machine share one semaphore
        semaphore_timeout: Max time to wait for semaphore acquisition (None = max(timeout, timeout * (limit - 1)))

    Example:
        @retry(wait=3, retries=3, timeout=5, semaphore_limit=3, semaphore_scope='self')
        async def some_function(self, ...):
            # Limited to 5s per attempt, retried up to 3 times on failure
            # Max 3 concurrent executions per instance

    Notes:
        - semaphore acquisition happens once up front; it is not retried
        - semaphore_timeout is only used if semaphore_limit is set
        - if semaphore_timeout is set to 0, it is bumped to a minimum of 0.01s (effectively one quick attempt)
        - if semaphore_timeout is set to None, the default max(timeout, timeout * (semaphore_limit - 1)) is used
        - retries counts retry attempts, so retries=1 means the function is called up to 2 times total (1 initial + 1 retry)
    """

    def decorator(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:  # type: ignore[return]
            # Get or create semaphore if needed
            semaphore: Any = None
            semaphore_acquired = False
            multiprocess_lock: Any = None

            if semaphore_limit is not None:
                # Determine semaphore key based on scope
                base_name = semaphore_name or func.__name__

                if semaphore_scope == 'multiprocess':
                    # Use multiprocess semaphore
                    sem_key = base_name
                    with MULTIPROCESS_SEMAPHORE_LOCK:
                        if sem_key not in MULTIPROCESS_SEMAPHORES:
                            # Create a NamedBoundedSemaphore with the given limit
                            MULTIPROCESS_SEMAPHORES[sem_key] = portalocker.utils.NamedBoundedSemaphore(
                                maximum=semaphore_limit,
                                name=sem_key,
                                directory=str(MULTIPROCESS_SEMAPHORE_DIR),
                            )
                        semaphore = MULTIPROCESS_SEMAPHORES[sem_key]
                else:
                    # Use in-process semaphore
                    if semaphore_scope == 'global':
                        sem_key = base_name
                    elif semaphore_scope == 'class' and args and hasattr(args[0], '__class__'):
                        # Use class name for class-level semaphore
                        class_name = args[0].__class__.__name__
                        sem_key = f'{class_name}.{base_name}'
                    elif semaphore_scope == 'self' and args:
                        # Use instance id for instance-level semaphore
                        instance_id = id(args[0])
                        sem_key = f'{instance_id}.{base_name}'
                    else:
                        # Fallback to global if we can't determine scope
                        sem_key = base_name

                    with GLOBAL_RETRY_SEMAPHORE_LOCK:
                        if sem_key not in GLOBAL_RETRY_SEMAPHORES:
                            GLOBAL_RETRY_SEMAPHORES[sem_key] = asyncio.Semaphore(semaphore_limit)
                        semaphore = GLOBAL_RETRY_SEMAPHORES[sem_key]

                # Try to acquire semaphore
                sem_start = time.time()
                # Calculate semaphore timeout
                if semaphore_timeout is None:
                    # Default: assume every other slot may be held by an operation running for
                    # the full per-attempt timeout; ensure a minimum of `timeout` when limit == 1
                    sem_timeout = max(timeout, timeout * (semaphore_limit - 1))
                else:
                    # Use provided timeout, but bump it to a minimum of 0.01s if 0 was passed
                    sem_timeout = max(0.01, semaphore_timeout) if semaphore_timeout == 0 else semaphore_timeout

                if semaphore_scope == 'multiprocess':
                    # Multiprocess semaphore uses context manager with timeout
                    try:
                        # Use a temporary thread to run the blocking operation
                        # portalocker returns a Lock object or None
                        multiprocess_lock = await asyncio.to_thread(
                            lambda: semaphore.acquire(timeout=sem_timeout, check_interval=0.05, fail_when_locked=False)
                        )
                        if multiprocess_lock:
                            semaphore_acquired = True
                        else:
                            raise TimeoutError(f'Could not acquire semaphore within {sem_timeout}s')
                    except Exception as e:
                        sem_wait_time = time.time() - sem_start
                        if 'Could not acquire' in str(e) or isinstance(e, TimeoutError):
                            if not semaphore_lax:
                                raise TimeoutError(
                                    f'Failed to acquire multiprocess semaphore "{sem_key}" within {sem_timeout}s '
                                    f'(limit={semaphore_limit}, timeout={timeout}s per operation)'
                                )
                            logger.warning(
                                f'Failed to acquire multiprocess semaphore "{sem_key}" after {sem_wait_time:.1f}s, '
                                f'proceeding without concurrency limit'
                            )
                        else:
                            raise
                else:
                    # Regular asyncio semaphore
                    try:
                        async with asyncio.timeout(sem_timeout):
                            await semaphore.acquire()
                            semaphore_acquired = True
                    except TimeoutError:
                        sem_wait_time = time.time() - sem_start
                        if not semaphore_lax:
                            raise TimeoutError(
                                f'Failed to acquire semaphore "{sem_key}" within {sem_timeout}s '
                                f'(limit={semaphore_limit}, timeout={timeout}s per operation)'
                            )
                        logger.warning(
                            f'Failed to acquire semaphore "{sem_key}" after {sem_wait_time:.1f}s, '
                            f'proceeding without concurrency limit'
                        )

            # Track active operations
            global _last_overload_check, _active_retry_operations
            with _active_operations_lock:
                _active_retry_operations += 1

            # Check for system overload (rate limited)
            current_time = time.time()
            if current_time - _last_overload_check > _overload_check_interval:
                _last_overload_check = current_time
                is_overloaded, reason = _check_system_overload()
                if is_overloaded:
                    logger.warning(
                        f'⚠️ System overload detected: {reason}. Consider reducing concurrent operations to prevent hanging.'
                    )

            # Execute function with retries
            start_time = time.time()
            last_exception = None

            try:
                for attempt in range(retries + 1):
                    try:
                        # Execute with per-attempt timeout
                        async with asyncio.timeout(timeout):
                            return await func(*args, **kwargs)

                    except Exception as e:
                        # Check if we should retry this exception
                        if retry_on is not None and not isinstance(e, retry_on):
                            raise

                        last_exception = e

                        if attempt < retries:
                            # Calculate wait time with backoff
                            current_wait = wait * (backoff_factor**attempt)

                            logger.warning(
                                f'{func.__name__} failed (attempt {attempt + 1}/{retries + 1}): '
                                f'{type(e).__name__}: {e}. Waiting {current_wait:.1f}s before retry...'
                            )
                            await asyncio.sleep(current_wait)
                        else:
                            # Final failure
                            total_time = time.time() - start_time
                            sem_wait = time.time() - sem_start - total_time if semaphore_limit else 0

                            logger.error(
                                f'{func.__name__} failed after {retries + 1} attempts over {total_time:.1f}s. '
                                f'Semaphore wait: {sem_wait:.1f}s. Final error: {type(e).__name__}: {e}'
                            )
                            raise

            finally:
                # Decrement active operations counter
                with _active_operations_lock:
                    _active_retry_operations = max(0, _active_retry_operations - 1)

                if semaphore_acquired and semaphore:
                    if semaphore_scope == 'multiprocess' and multiprocess_lock:
                        # Release the lock object for portalocker
                        await asyncio.to_thread(lambda: multiprocess_lock.release())
                    elif semaphore:
                        semaphore.release()

        return wrapper

    return decorator

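
# Illustrative sketch (hypothetical function and exception choice): retry_on
# restricts which exceptions trigger a retry, and backoff_factor=2.0 doubles
# the wait after each failed attempt (1s, then 2s).
@retry(wait=1, retries=2, timeout=5, retry_on=(ValueError,), backoff_factor=2.0)
async def _example_flaky_call(x: int) -> int:
    if x < 0:
        raise ValueError('negative input')  # retried twice, then re-raised
    return x * 2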