diff --git a/.env.example b/.env.example index dfd481610..413c04aae 100644 --- a/.env.example +++ b/.env.example @@ -4,5 +4,5 @@ ANTHROPIC_API_KEY= # Set to false to disable anonymized telemetry ANONYMIZED_TELEMETRY=true -# Set to true to enable verbose logging -BROWSER_USE_DEBUG_LOGGING=true \ No newline at end of file +# LogLevel: Set to debug to enable verbose logging, set to result to get results only. Available: result | debug | info +BROWSER_USE_LOGGING_LEVEL=info \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 5f930522e..022f02a10 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -60,14 +60,26 @@ "justMyCode": false }, { - "name": "Python: Debug Agent History", + "name": "Python: Debug Core Functionality", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/.venv/bin/pytest", + "args": [ + "tests/test_core_functionality.py", + "-v" + ], + "console": "integratedTerminal", + "justMyCode": false + }, + { + "name": "Python: Debug Current File", "type": "python", "request": "launch", "module": "pytest", "args": [ - "browser_use/agent/tests.py", + "${file}", "-v", - "--capture=no", + "--capture=no" ], "console": "integratedTerminal", "justMyCode": false diff --git a/README.md b/README.md index e8a4adbb6..8ca2754f1 100644 --- a/README.md +++ b/README.md @@ -17,19 +17,29 @@ With pip: pip install browser-use ``` +(optional) install playwright: + +```bash +playwright install +``` + Spin up your agent: ```python from langchain_openai import ChatOpenAI from browser_use import Agent +import asyncio -agent = Agent( - task="Find a one-way flight from Bali to Oman on 12 January 2025 on Google Flights. Return me the cheapest option.", - llm=ChatOpenAI(model="gpt-4o"), -) - -# ... inside an async function -await agent.run() +async def main(): + agent = Agent( + task="Find a one-way flight from Bali to Oman on 12 January 2025 on Google Flights. Return me the cheapest option.", + llm=ChatOpenAI(model="gpt-4o"), + ) + result = await agent.run() + print(result) + +if __name__ == "__main__": + asyncio.run(main()) ``` And don't forget to add your API keys to your `.env` file. @@ -71,6 +81,8 @@ https://github.com/user-attachments/assets/de73ee39-432c-4b97-b4e8-939fd7f323b3 If you want to add custom actions your agent can take, you can register them like this: +You can use BOTH sync or async functions. + ```python from browser_use.agent.service import Agent from browser_use.browser.service import Browser @@ -94,11 +106,12 @@ class JobDetails(BaseModel): salary: Optional[str] = None @controller.action('Save job details which you found on page', param_model=JobDetails, requires_browser=True) -def save_job(params: JobDetails, browser: Browser): +async def save_job(params: JobDetails, browser: Browser): print(params) # use the browser normally - browser.driver.get(params.job_link) + page = browser.get_current_page() + page.go_to(params.job_link) ``` and then run your agent: diff --git a/browser_use/__init__.py b/browser_use/__init__.py index cefa744e3..05c0eedc5 100644 --- a/browser_use/__init__.py +++ b/browser_use/__init__.py @@ -8,4 +8,4 @@ from browser_use.browser.service import Browser as Browser from browser_use.controller.service import Controller as Controller from browser_use.dom.service import DomService -__all__ = ['Agent', 'Browser', 'Controller', 'DomService', 'SystemPrompt'] +__all__ = ["Agent", "Browser", "Controller", "DomService", "SystemPrompt"] diff --git a/browser_use/agent/service.py b/browser_use/agent/service.py index 1da87c2cc..9dec94c29 100644 --- a/browser_use/agent/service.py +++ b/browser_use/agent/service.py @@ -117,13 +117,15 @@ class Agent: async def step(self) -> None: """Execute one step of the task""" logger.info(f'\nšŸ“ Step {self.n_steps}') - state = self.controller.browser.get_state(use_vision=self.use_vision) + state = await self.controller.browser.get_state(use_vision=self.use_vision) try: model_output = await self.get_next_action(state) - result = self.controller.act(model_output.action) + result = await self.controller.act(model_output.action) if result.extracted_content: logger.info(f'šŸ“„ Result: {result.extracted_content}') + if result.is_done: + logger.result(f'{result.extracted_content}') self.consecutive_failures = 0 except Exception as e: @@ -410,7 +412,7 @@ class Agent: ) ) if not self.controller_injected: - self.controller.browser.close() + await self.controller.browser.close() def _too_many_failures(self) -> bool: """Check if we should stop due to too many failures""" diff --git a/browser_use/browser/service.py b/browser_use/browser/service.py index 6a3b7ad7a..ee7a0796a 100644 --- a/browser_use/browser/service.py +++ b/browser_use/browser/service.py @@ -1,26 +1,17 @@ """ -Selenium browser on steroids. +Playwright browser on steroids. """ +import asyncio import base64 import logging -import os -import tempfile import time -from typing import Literal +from dataclasses import dataclass -from Screenshot import Screenshot -from selenium import webdriver -from selenium.webdriver.chrome.options import Options -from selenium.webdriver.chrome.service import Service as ChromeService -from selenium.webdriver.common.action_chains import ActionChains -from selenium.webdriver.common.by import By -from selenium.webdriver.remote.webelement import WebElement -from selenium.webdriver.support import expected_conditions as EC -from selenium.webdriver.support.ui import WebDriverWait -from webdriver_manager.chrome import ChromeDriverManager +from playwright.async_api import Browser as PlaywrightBrowser +from playwright.async_api import BrowserContext, ElementHandle, Page, Playwright, async_playwright -from browser_use.browser.views import BrowserState, TabInfo +from browser_use.browser.views import BrowserError, BrowserState, TabInfo from browser_use.dom.service import DomService from browser_use.dom.views import SelectorMap from browser_use.utils import time_execution_sync @@ -28,119 +19,158 @@ from browser_use.utils import time_execution_sync logger = logging.getLogger(__name__) +@dataclass +class BrowserSession: + playwright: Playwright + browser: PlaywrightBrowser + context: BrowserContext + current_page: Page + cached_state: BrowserState + # current_page_id: str + # opened_tabs: dict[str, TabInfo] = field(default_factory=dict) + + class Browser: + MINIMUM_WAIT_TIME = 0.5 + MAXIMUM_WAIT_TIME = 5 + def __init__(self, headless: bool = False, keep_open: bool = False): self.headless = headless self.keep_open = keep_open - self.MINIMUM_WAIT_TIME = 0.5 - self.MAXIMUM_WAIT_TIME = 5 - self._tab_cache: dict[str, TabInfo] = {} - self._current_handle = None - self._ob = Screenshot.Screenshot() - # Initialize driver during construction - self.driver: webdriver.Chrome | None = self._setup_webdriver() - self._cached_state = self._update_state() + # Initialize these as None - they'll be set up when needed + self.session: BrowserSession | None = None - def _setup_webdriver(self) -> webdriver.Chrome: - """Sets up and returns a Selenium WebDriver instance with anti-detection measures.""" + async def _initialize_session(self): + """Initialize the browser session""" + playwright = await async_playwright().start() + browser = await self._setup_browser(playwright) + context = await self._create_context(browser) + page = await context.new_page() + + # Instead of calling _update_state(), create an empty initial state + initial_state = BrowserState( + items=[], + selector_map={}, + url=page.url, + title=await page.title(), + screenshot=None, + tabs=[], + ) + + self.session = BrowserSession( + playwright=playwright, + browser=browser, + context=context, + current_page=page, + cached_state=initial_state, + ) + + return self.session + + async def get_session(self) -> BrowserSession: + """Lazy initialization of the browser and related components""" + if self.session is None: + return await self._initialize_session() + return self.session + + async def get_current_page(self) -> Page: + """Get the current page""" + session = await self.get_session() + return session.current_page + + async def _setup_browser(self, playwright: Playwright) -> PlaywrightBrowser: + """Sets up and returns a Playwright Browser instance with anti-detection measures.""" try: - # if webdriver is not starting, try to kill it or rm -rf ~/.wdm - chrome_options = Options() - if self.headless: - chrome_options.add_argument('--headless=new') # Updated headless argument - - # Essential automation and performance settings - chrome_options.add_argument('--disable-blink-features=AutomationControlled') - chrome_options.add_experimental_option('excludeSwitches', ['enable-automation']) - chrome_options.add_experimental_option('useAutomationExtension', False) - chrome_options.add_argument('--no-sandbox') - chrome_options.add_argument('--window-size=1280,1024') - chrome_options.add_argument('--disable-extensions') - - # Background process optimization - chrome_options.add_argument('--disable-background-timer-throttling') - chrome_options.add_argument('--disable-popup-blocking') - - # Additional stealth settings - chrome_options.add_argument('--disable-infobars') - # Much better when working in non-headless mode - chrome_options.add_argument('--disable-backgrounding-occluded-windows') - chrome_options.add_argument('--disable-renderer-backgrounding') - - # Initialize the Chrome driver with better error handling - service = ChromeService(ChromeDriverManager().install()) - driver = webdriver.Chrome(service=service, options=chrome_options) - - # Execute stealth scripts - driver.execute_cdp_cmd( - 'Page.addScriptToEvaluateOnNewDocument', - { - 'source': """ - Object.defineProperty(navigator, 'webdriver', { - get: () => undefined - }); - - Object.defineProperty(navigator, 'languages', { - get: () => ['en-US', 'en'] - }); - - Object.defineProperty(navigator, 'plugins', { - get: () => [1, 2, 3, 4, 5] - }); - - window.chrome = { - runtime: {} - }; - - Object.defineProperty(navigator, 'permissions', { - get: () => ({ - query: Promise.resolve.bind(Promise) - }) - }); - """ - }, + browser = await playwright.chromium.launch( + headless=self.headless, + ignore_default_args=['--enable-automation'], # Helps with anti-detection + args=[ + '--no-sandbox', + '--disable-blink-features=AutomationControlled', + '--disable-extensions', + '--disable-infobars', + '--disable-background-timer-throttling', + '--disable-popup-blocking', + '--disable-backgrounding-occluded-windows', + '--disable-renderer-backgrounding', + '--disable-window-activation', + '--disable-focus-on-load', # Prevents focus on navigation + '--no-first-run', + '--no-default-browser-check', + '--no-startup-window', # Prevents initial focus + '--window-position=0,0', + ], ) - return driver - + return browser except Exception as e: - logger.error(f'Failed to initialize Chrome driver: {str(e)}') - # Clean up any existing driver - if hasattr(self, 'driver') and self.driver: - try: - self.driver.quit() - self.driver = None - except Exception: - pass + logger.error(f'Failed to initialize Playwright browser: {str(e)}') raise - def _get_driver(self) -> webdriver.Chrome: - if self.driver is None: - self.driver = self._setup_webdriver() - return self.driver + async def _create_context(self, browser: PlaywrightBrowser): + """Creates a new browser context with anti-detection measures.""" + context = await browser.new_context( + viewport={'width': 1280, 'height': 1024}, + user_agent=( + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/85.0.4183.102 Safari/537.36' + ), + java_script_enabled=True, + ) - def wait_for_page_load(self): + # Expose anti-detection scripts + await context.add_init_script( + """ + // Webdriver property + Object.defineProperty(navigator, 'webdriver', { + get: () => undefined + }); + + // Languages + Object.defineProperty(navigator, 'languages', { + get: () => ['en-US', 'en'] + }); + + // Plugins + Object.defineProperty(navigator, 'plugins', { + get: () => [1, 2, 3, 4, 5] + }); + + // Chrome runtime + window.chrome = { runtime: {} }; + + // Permissions + const originalQuery = window.navigator.permissions.query; + window.navigator.permissions.query = (parameters) => ( + parameters.name === 'notifications' ? + Promise.resolve({ state: Notification.permission }) : + originalQuery(parameters) + ); + """ + ) + + return context + + async def wait_for_page_load(self, timeout_overwrite: float | None = None): """ Ensures page is fully loaded before continuing. Waits for either document.readyState to be complete or minimum WAIT_TIME, whichever is longer. """ - driver = self._get_driver() + page = await self.get_current_page() # Start timing start_time = time.time() # Wait for page load try: - WebDriverWait(driver, 5).until( - lambda d: d.execute_script('return document.readyState') == 'complete' - ) + await page.wait_for_load_state('load', timeout=5000) except Exception: pass # Calculate remaining time to meet minimum WAIT_TIME elapsed = time.time() - start_time - remaining = max(self.MINIMUM_WAIT_TIME - elapsed, 0) + remaining = max((timeout_overwrite or self.MINIMUM_WAIT_TIME) - elapsed, 0) logger.debug( f'--Page loaded in {elapsed:.2f} seconds, waiting for additional {remaining:.2f} seconds' @@ -148,111 +178,148 @@ class Browser: # Sleep remaining time if needed if remaining > 0: - time.sleep(remaining) + await asyncio.sleep(remaining) - def _update_state(self, use_vision: bool = False) -> BrowserState: - """ - Update and return state. - """ - driver = self._get_driver() - dom_service = DomService(driver) - content = dom_service.get_clickable_elements() + async def close(self, force: bool = False): + """Close the browser instance""" + if force and not self.keep_open: + session = await self.get_session() + await session.browser.close() + await session.playwright.stop() + else: + # Note: input() is blocking - consider an async alternative if needed + input('Press Enter to close Browser...') + self.keep_open = False + await self.close(force=True) + + def __del__(self): + """Async cleanup when object is destroyed""" + if self.session is not None: + asyncio.run(self.close(force=True)) + + async def navigate_to(self, url: str): + """Navigate to a URL""" + page = await self.get_current_page() + await page.goto(url) + await self.wait_for_page_load() + + async def refresh_page(self): + """Refresh the current page""" + page = await self.get_current_page() + await page.reload() + await self.wait_for_page_load() + + async def go_back(self): + """Navigate back in history""" + page = await self.get_current_page() + await page.go_back() + await self.wait_for_page_load() + + async def go_forward(self): + """Navigate forward in history""" + page = await self.get_current_page() + await page.go_forward() + await self.wait_for_page_load() + + async def close_current_tab(self): + """Close the current tab""" + session = await self.get_session() + page = session.current_page + await page.close() + + # Switch to the first available tab if any exist + if session.context.pages: + await self.switch_to_tab(0) + + # otherwise the browser will be closed + + async def get_page_html(self) -> str: + """Get the current page HTML content""" + page = await self.get_current_page() + return await page.content() + + async def execute_javascript(self, script: str): + """Execute JavaScript code on the page""" + page = await self.get_current_page() + return await page.evaluate(script) + + @time_execution_sync('--get_state') # This decorator might need to be updated to handle async + async def get_state(self, use_vision: bool = False) -> BrowserState: + """Get the current state of the browser""" + session = await self.get_session() + session.cached_state = await self._update_state(use_vision=use_vision) + return session.cached_state + + async def _update_state(self, use_vision: bool = False) -> BrowserState: + """Update and return state.""" + page = await self.get_current_page() + dom_service = DomService(page) + content = await dom_service.get_clickable_elements() # Assuming this is async screenshot_b64 = None if use_vision: - screenshot_b64 = self.take_screenshot(selector_map=content.selector_map) + screenshot_b64 = await self.take_screenshot(selector_map=content.selector_map) self.current_state = BrowserState( items=content.items, selector_map=content.selector_map, - url=driver.current_url, - title=driver.title, - current_tab_handle=self._current_handle or driver.current_window_handle, - tabs=self.get_tabs_info(), + url=page.url, + title=await page.title(), + tabs=await self.get_tabs_info(), screenshot=screenshot_b64, ) return self.current_state - def close(self, force: bool = False): - if not self.keep_open or force: - if self.driver: - driver = self._get_driver() - driver.quit() - self.driver = None - else: - input('Press Enter to close Browser...') - self.keep_open = False - self.close() - - def __del__(self): - """ - Close the browser driver when instance is destroyed. - """ - if self.driver is not None: - self.close() - # region - Browser Actions - def take_screenshot(self, selector_map: SelectorMap | None, full_page: bool = False) -> str: + async def take_screenshot( + self, selector_map: SelectorMap | None, full_page: bool = False + ) -> str: """ Returns a base64 encoded screenshot of the current page. """ - driver = self._get_driver() + page = await self.get_current_page() if selector_map: - self.highlight_selector_map_elements(selector_map) + await self.highlight_selector_map_elements(selector_map) - if full_page: - # Create temp directory - temp_dir = tempfile.mkdtemp() - screenshot = self._ob.full_screenshot( - driver, - save_path=temp_dir, - image_name='temp.png', - is_load_at_runtime=True, - load_wait_time=1, - ) + screenshot = await page.screenshot( + full_page=full_page, + animations='disabled', + ) - # Read file as base64 - with open(os.path.join(temp_dir, 'temp.png'), 'rb') as img: - screenshot = base64.b64encode(img.read()).decode('utf-8') - - # Cleanup temp directory - os.remove(os.path.join(temp_dir, 'temp.png')) - os.rmdir(temp_dir) - else: - screenshot = driver.get_screenshot_as_base64() + screenshot_b64 = base64.b64encode(screenshot).decode('utf-8') if selector_map: - self.remove_highlights() + await self.remove_highlights() - return screenshot + return screenshot_b64 - def highlight_selector_map_elements(self, selector_map: SelectorMap): - driver = self._get_driver() - # First remove any existing highlights/labels - self.remove_highlights() + async def highlight_selector_map_elements(self, selector_map: SelectorMap): + page = await self.get_current_page() + await self.remove_highlights() script = """ const highlights = { """ - # Build the highlights object with all xpaths and indices - for index, xpath in selector_map.items(): - script += f'"{index}": "{xpath}",\n' + # Build the highlights object with all selectors and indices + for index, selector in selector_map.items(): + # Adjusting the JavaScript code to accept variables + script += f'"{index}": "{selector}",\n' script += """ }; - for (const [index, xpath] of Object.entries(highlights)) { - const el = document.evaluate(xpath, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue; + for (const [index, selector] of Object.entries(highlights)) { + const el = document.evaluate(selector, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue; if (!el) continue; // Skip if element not found el.style.outline = "2px solid red"; - el.setAttribute('browser-user-highlight-id', 'selenium-highlight'); + el.setAttribute('browser-user-highlight-id', 'playwright-highlight'); const label = document.createElement("div"); - label.className = 'selenium-highlight-label'; + label.className = 'playwright-highlight-label'; label.style.position = "fixed"; label.style.background = "red"; label.style.color = "white"; @@ -268,25 +335,26 @@ class Browser: } """ - driver.execute_script(script) + await page.evaluate(script) - def remove_highlights(self): + async def remove_highlights(self): """ Removes all highlight outlines and labels created by highlight_selector_map_elements + """ - driver = self._get_driver() - driver.execute_script( + page = await self.get_current_page() + await page.evaluate( """ // Remove all highlight outlines - const highlightedElements = document.querySelectorAll('[browser-user-highlight-id="selenium-highlight"]'); + const highlightedElements = document.querySelectorAll('[browser-user-highlight-id="playwright-highlight"]'); highlightedElements.forEach(el => { el.style.outline = ''; - el.removeAttribute('selenium-browser-use-highlight'); + el.removeAttribute('browser-user-highlight-id'); }); // Remove all labels - const labels = document.querySelectorAll('.selenium-highlight-label'); + const labels = document.querySelectorAll('.playwright-highlight-label'); labels.forEach(label => label.remove()); """ ) @@ -294,78 +362,50 @@ class Browser: # endregion # region - User Actions - def _webdriver_wait(self): - driver = self._get_driver() - return WebDriverWait(driver, 10) - def _input_text_by_xpath(self, xpath: str, text: str): - driver = self._get_driver() + async def _input_text_by_xpath(self, xpath: str, text: str): + page = await self.get_current_page() try: - # Wait for element to be both present and interactable - element = self._webdriver_wait().until(EC.element_to_be_clickable((By.XPATH, xpath))) + element = await page.wait_for_selector(f'xpath={xpath}', timeout=5000, state='visible') - # Scroll element into view using ActionChains for smoother scrolling - actions = ActionChains(driver) - actions.move_to_element(element).perform() + if element is None: + raise Exception(f'Element with xpath: {xpath} not found') - # Try to clear using JavaScript first - driver.execute_script("arguments[0].value = '';", element) - - # Then send keys - element.send_keys(text) - - self.wait_for_page_load() + await element.scroll_into_view_if_needed(timeout=2500) + await element.fill('') + await element.type(text) + await self.wait_for_page_load() except Exception as e: raise Exception( f'Failed to input text into element with xpath: {xpath}. Error: {str(e)}' ) - def _click_element_by_xpath(self, xpath: str): + async def _click_element_by_xpath(self, xpath: str): """ Optimized method to click an element using xpath. """ - driver = self._get_driver() - wait = self._webdriver_wait() + page = await self.get_current_page() try: - # First try the direct approach with a shorter timeout + element = await page.wait_for_selector(f'xpath={xpath}', timeout=5000, state='visible') + + if element is None: + raise Exception(f'Element with xpath: {xpath} not found') + + # await element.scroll_into_view_if_needed() + try: - element = wait.until( - EC.element_to_be_clickable((By.XPATH, xpath)), - message=f'Element not clickable: {xpath}', - ) - element.click() - self.wait_for_page_load() + await element.click(timeout=2500) + await self.wait_for_page_load() return except Exception: pass - # If that fails, try a simplified approach try: - # Try with ID if present in xpath - if 'id=' in xpath: - id_value = xpath.split('id=')[-1].split(']')[0] - element = driver.find_element(By.ID, id_value) - if element.is_displayed() and element.is_enabled(): - driver.execute_script('arguments[0].click();', element) - self.wait_for_page_load() - return - except Exception: - pass - - # Last resort: force click with JavaScript - try: - element = driver.find_element(By.XPATH, xpath) - driver.execute_script( - """ - arguments[0].scrollIntoView({behavior: 'instant', block: 'center'}); - arguments[0].click(); - """, - element, - ) - self.wait_for_page_load() + await page.evaluate('(el) => el.click()', element) + await self.wait_for_page_load() return except Exception as e: raise Exception(f'Failed to click element: {str(e)}') @@ -373,71 +413,63 @@ class Browser: except Exception as e: raise Exception(f'Failed to click element with xpath: {xpath}. Error: {str(e)}') - def handle_new_tab(self) -> None: - """Handle newly opened tab and switch to it""" - driver = self._get_driver() - handles = driver.window_handles - new_handle = handles[-1] # Get most recently opened handle - - # Switch to new tab - driver.switch_to.window(new_handle) - self._current_handle = new_handle - - # Wait for page load - self.wait_for_page_load() - - # Create and cache tab info - tab_info = TabInfo(handle=new_handle, url=driver.current_url, title=driver.title) - self._tab_cache[new_handle] = tab_info - - def get_tabs_info(self) -> list[TabInfo]: + async def get_tabs_info(self) -> list[TabInfo]: """Get information about all tabs""" - driver = self._get_driver() - current_handle = driver.current_window_handle - self._current_handle = current_handle + session = await self.get_session() tabs_info = [] - for handle in driver.window_handles: - # Use cached info if available, otherwise get new info - if handle in self._tab_cache: - tab_info = self._tab_cache[handle] - else: - # Only switch if we need to get info - if handle != current_handle: - driver.switch_to.window(handle) - tab_info = TabInfo(handle=handle, url=driver.current_url, title=driver.title) - self._tab_cache[handle] = tab_info - + for page_id, page in enumerate(session.context.pages): + tab_info = TabInfo(page_id=page_id, url=page.url, title=await page.title()) tabs_info.append(tab_info) - # Switch back to current tab if we moved - if driver.current_window_handle != current_handle: - driver.switch_to.window(current_handle) - return tabs_info + async def switch_to_tab(self, page_id: int) -> None: + """Switch to a specific tab by its page_id + + @You can also use negative indices to switch to tabs from the end (Pure pythonic way) + """ + session = await self.get_session() + pages = session.context.pages + + if page_id >= len(pages): + raise BrowserError(f'No tab found with page_id: {page_id}') + + page = pages[page_id] + session.current_page = page + + await page.bring_to_front() + await self.wait_for_page_load() + + async def create_new_tab(self, url: str | None = None) -> None: + """Create a new tab and optionally navigate to a URL""" + session = await self.get_session() + new_page = await session.context.new_page() + session.current_page = new_page + + await self.wait_for_page_load() + + page = await self.get_current_page() + + if url: + await page.goto(url) + await self.wait_for_page_load(timeout_overwrite=1) + # endregion - @time_execution_sync('--get_state') - def get_state(self, use_vision: bool = False) -> BrowserState: - """ - Get the current state of the browser including page content and tab information. - """ - self._cached_state = self._update_state(use_vision=use_vision) - return self._cached_state + # region - Helper methods for easier access to the DOM + async def get_selector_map(self) -> SelectorMap: + session = await self.get_session() + return session.cached_state.selector_map - @property - def selector_map(self) -> SelectorMap: - return self._cached_state.selector_map + async def get_xpath(self, index: int) -> str: + selector_map = await self.get_selector_map() + return selector_map[index] - def xpath(self, index: int) -> str: - return self.selector_map[index] + async def get_element_by_index(self, index: int) -> ElementHandle | None: + page = await self.get_current_page() + return await page.wait_for_selector( + await self.get_xpath(index), timeout=2500, state='visible' + ) - def get_element(self, index: int) -> WebElement: - driver = self._get_driver() - return driver.find_element(By.XPATH, self.xpath(index)) - - def wait_for_element(self, css_selector: str, timeout: int = 10) -> WebElement: - """Wait for an element to appear and return it.""" - wait = WebDriverWait(self._get_driver(), timeout) - return wait.until(EC.presence_of_element_located((By.CSS_SELECTOR, css_selector))) + # endregion diff --git a/browser_use/browser/tests/playwright_test.py b/browser_use/browser/tests/playwright_test.py new file mode 100644 index 000000000..1107841e0 --- /dev/null +++ b/browser_use/browser/tests/playwright_test.py @@ -0,0 +1,52 @@ +import time + +import pytest +from playwright.async_api import Page + +from browser_use.dom.service import DomService + + +@pytest.fixture(scope='session') +def browser_type_launch_args(): + return {'headless': False} + + +async def test_has_title(page: Page): + dom_service = DomService(page) + + await page.goto('https://www.immobilienscout24.de') + await page.wait_for_timeout(2000) + + # Get all DOM content including all shadow roots recursively + start_time = time.time() + full_content = await dom_service._get_html_content() + # full_content = page.evaluate("""() => { + # function getAllContent(root) { + # let content = ''; + # // Get all elements in the current root + # const elements = root.querySelectorAll('*'); + + # elements.forEach(element => { + # // Add the element's outer HTML + # content += element.outerHTML; + # // If element has shadow root, recursively get its content + # if (element.shadowRoot) { + # content += `\\n\\n`; + # content += getAllContent(element.shadowRoot); + # content += `\\n\\n`; + # } + # }); + # return content; + # } + # return getAllContent(document); + # }""") + end_time = time.time() + + print(full_content) + print(f'Time taken to get DOM content: {end_time - start_time:.2f} seconds') + + elements = dom_service._process_content(full_content) + + print(elements) + + input('Press Enter to continue...') diff --git a/browser_use/browser/tests/screenshot_test.py b/browser_use/browser/tests/screenshot_test.py index 3f479c37c..084e64458 100644 --- a/browser_use/browser/tests/screenshot_test.py +++ b/browser_use/browser/tests/screenshot_test.py @@ -6,10 +6,11 @@ from browser_use.browser.service import Browser @pytest.fixture -def browser(): +async def browser(): browser_service = Browser(headless=True) yield browser_service - browser_service.close() + + await browser_service.close() # @pytest.mark.skip(reason='takes too long') diff --git a/browser_use/browser/tests/test_clicks.py b/browser_use/browser/tests/test_clicks.py index 99e1d541f..248491827 100644 --- a/browser_use/browser/tests/test_clicks.py +++ b/browser_use/browser/tests/test_clicks.py @@ -1,59 +1,69 @@ import time -from browser_use.browser.service import Browser +import pytest + +from browser_use.browser.service import Browser from browser_use.utils import time_execution_sync -def test_highlight_elements(): - browser = Browser() +@pytest.mark.asyncio +async def test_highlight_elements(): + browser = Browser(headless=False, keep_open=False) - browser._get_driver().get('https://kayak.com') - # browser.go_to_url('https://google.com/flights') - # browser.go_to_url('https://immobilienscout24.de') + session = await browser.get_session() - time.sleep(1) - # browser._click_element_by_xpath( - # '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]' - # ) - # browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]") + print(session) - while True: - state = browser.get_state() + page = await browser.get_current_page() + # await page.goto('https://immobilienscout24.de') + await page.goto("https://kayak.com") - time_execution_sync('highlight_selector_map_elements')( - browser.highlight_selector_map_elements - )(state.selector_map) + time.sleep(3) + # browser._click_element_by_xpath( + # '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]' + # ) + # browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]") - print(state.dom_items_to_string(use_tabs=False)) - # print(state.selector_map) + while True: + state = await browser.get_state() - # Find and print duplicate XPaths - xpath_counts = {} - for selector in state.selector_map.values(): - if selector in xpath_counts: - xpath_counts[selector] += 1 - else: - xpath_counts[selector] = 1 + await time_execution_sync("highlight_selector_map_elements")( + browser.highlight_selector_map_elements + )(state.selector_map) - print('\nDuplicate XPaths found:') - for xpath, count in xpath_counts.items(): - if count > 1: - print(f'XPath: {xpath}') - print(f'Count: {count}\n') + print(state.dom_items_to_string(use_tabs=False)) + # print(state.selector_map) - print(state.selector_map.keys(), 'Selector map keys') - action = input('Select next action: ') + # Find and print duplicate XPaths + xpath_counts = {} + for selector in state.selector_map.values(): + if selector in xpath_counts: + xpath_counts[selector] += 1 + else: + xpath_counts[selector] = 1 - time_execution_sync('remove_highlight_elements')(browser.remove_highlights)() + print("\nDuplicate XPaths found:") + for xpath, count in xpath_counts.items(): + if count > 1: + print(f"XPath: {xpath}") + print(f"Count: {count}\n") - xpath = state.selector_map[int(action)] + print(state.selector_map.keys(), "Selector map keys") + action = input("Select next action: ") - browser._click_element_by_xpath(xpath) + await time_execution_sync("remove_highlight_elements")( + browser.remove_highlights + )() + xpath = state.selector_map[int(action)] -def main(): - test_highlight_elements() + # check if index of selector map are the same as index of items in dom_items + indcies = list(state.selector_map.keys()) + dom_items = state.items + dom_indices = [item.index for item in dom_items if not item.is_text_only] + assert ( + indcies == dom_indices + ), "Indices of selector map and dom items do not match" -if __name__ == '__main__': - main() + await browser._click_element_by_xpath(xpath) diff --git a/browser_use/browser/tests/test_selenium.py b/browser_use/browser/tests/test_selenium.py deleted file mode 100644 index ae0b51f1b..000000000 --- a/browser_use/browser/tests/test_selenium.py +++ /dev/null @@ -1,50 +0,0 @@ -import time - -import pytest -from selenium import webdriver -from selenium.webdriver.chrome.options import Options -from selenium.webdriver.chrome.service import Service -from webdriver_manager.chrome import ChromeDriverManager - - -def test_selenium(): - try: - print('1. Setting up Chrome options...') - chrome_options = Options() - chrome_options.add_argument('--no-sandbox') - # Uncomment to test headless mode - # chrome_options.add_argument('--headless=new') - - print('2. Installing/finding ChromeDriver...') - service = Service(ChromeDriverManager().install()) - - print('3. Creating Chrome WebDriver...') - driver = webdriver.Chrome(service=service, options=chrome_options) - - print('4. Navigating to Google...') - driver.get('https://www.google.com') - - print('5. Getting page title...') - title = driver.title - print(f'Page title: {title}') - - time.sleep(2) # Wait to see the page if not in headless mode - - print('6. Closing browser...') - driver.quit() - - print('āœ… Test completed successfully!') - return True - - except Exception as e: - print(f'āŒ Test failed with error: {str(e)}') - print(f'Error type: {type(e).__name__}') - return False - - -# run with: pytest browser_use/browser/tests/test_selenium.py - -# - -if __name__ == '__main__': - pytest.main([__file__, '-v']) diff --git a/browser_use/browser/views.py b/browser_use/browser/views.py index 244d6b0c4..f87819dd8 100644 --- a/browser_use/browser/views.py +++ b/browser_use/browser/views.py @@ -9,7 +9,7 @@ from browser_use.dom.views import ProcessedDomContent class TabInfo(BaseModel): """Represents information about a browser tab""" - handle: str + page_id: int url: str title: str @@ -17,7 +17,6 @@ class TabInfo(BaseModel): class BrowserState(ProcessedDomContent): url: str title: str - current_tab_handle: str tabs: list[TabInfo] screenshot: Optional[str] = None @@ -29,3 +28,7 @@ class BrowserState(ProcessedDomContent): f'Tab {i+1}: {tab.title} ({tab.url})' for i, tab in enumerate(self.tabs) ] return dump + + +class BrowserError(Exception): + """Base class for all browser errors""" diff --git a/browser_use/controller/registry/service.py b/browser_use/controller/registry/service.py index 032ab316a..a153fd63b 100644 --- a/browser_use/controller/registry/service.py +++ b/browser_use/controller/registry/service.py @@ -1,4 +1,5 @@ -from inspect import signature +import asyncio +from inspect import iscoroutinefunction, signature from typing import Any, Callable, Optional, Type from pydantic import BaseModel, create_model @@ -50,10 +51,24 @@ class Registry: # Create param model from function if not provided actual_param_model = param_model or self._create_param_model(func) + # Wrap sync functions to make them async + if not iscoroutinefunction(func): + + async def async_wrapper(*args, **kwargs): + return await asyncio.to_thread(func, *args, **kwargs) + + # Copy the signature and other metadata from the original function + async_wrapper.__signature__ = signature(func) + async_wrapper.__name__ = func.__name__ + async_wrapper.__annotations__ = func.__annotations__ + wrapped_func = async_wrapper + else: + wrapped_func = func + action = RegisteredAction( name=func.__name__, description=description, - function=func, + function=wrapped_func, param_model=actual_param_model, requires_browser=requires_browser, ) @@ -62,7 +77,7 @@ class Registry: return decorator - def execute_action( + async def execute_action( self, action_name: str, params: dict, browser: Optional[Browser] = None ) -> Any: """Execute a registered action""" @@ -79,17 +94,19 @@ class Registry: parameters = list(sig.parameters.values()) is_pydantic = parameters and issubclass(parameters[0].annotation, BaseModel) - # Execute with or without browser + # Prepare arguments based on parameter type if action.requires_browser: if not browser: - raise ValueError(f'Action {action_name} requires browser but none provided') + raise ValueError( + f'Action {action_name} requires browser but none provided. This has to be used in combination of `requires_browser=True` when registering the action.' + ) if is_pydantic: - return action.function(validated_params, browser=browser) - return action.function(**validated_params.model_dump(), browser=browser) + return await action.function(validated_params, browser=browser) + return await action.function(**validated_params.model_dump(), browser=browser) if is_pydantic: - return action.function(validated_params) - return action.function(**validated_params.model_dump()) + return await action.function(validated_params) + return await action.function(**validated_params.model_dump()) except Exception as e: raise RuntimeError(f'Error executing action {action_name}: {str(e)}') from e diff --git a/browser_use/controller/service.py b/browser_use/controller/service.py index 3c8461bf0..b759baf9d 100644 --- a/browser_use/controller/service.py +++ b/browser_use/controller/service.py @@ -1,12 +1,9 @@ import logging from main_content_extractor import MainContentExtractor -from selenium.webdriver.common.by import By -from selenium.webdriver.common.keys import Keys from browser_use.agent.views import ActionModel, ActionResult from browser_use.browser.service import Browser -from browser_use.browser.views import TabInfo from browser_use.controller.registry.service import Registry from browser_use.controller.views import ( ClickElementAction, @@ -15,7 +12,7 @@ from browser_use.controller.views import ( GoToUrlAction, InputTextAction, OpenTabAction, - ScrollDownAction, + ScrollAction, SearchGoogleAction, SwitchTabAction, ) @@ -25,8 +22,8 @@ logger = logging.getLogger(__name__) class Controller: - def __init__(self, keep_open: bool = False): - self.browser = Browser(keep_open=keep_open) + def __init__(self, headless: bool = False, keep_open: bool = False): + self.browser = Browser(headless=headless, keep_open=keep_open) self.registry = Registry() self._register_default_actions() @@ -37,29 +34,30 @@ class Controller: @self.registry.action( 'Search Google', param_model=SearchGoogleAction, requires_browser=True ) - def search_google(params: SearchGoogleAction, browser: Browser): - driver = browser._get_driver() - driver.get(f'https://www.google.com/search?q={params.query}') - browser.wait_for_page_load() + async def search_google(params: SearchGoogleAction, browser: Browser): + page = await browser.get_current_page() + await page.goto(f'https://www.google.com/search?q={params.query}') + await browser.wait_for_page_load() @self.registry.action('Navigate to URL', param_model=GoToUrlAction, requires_browser=True) - def go_to_url(params: GoToUrlAction, browser: Browser): - driver = browser._get_driver() - driver.get(params.url) - browser.wait_for_page_load() + async def go_to_url(params: GoToUrlAction, browser: Browser): + page = await browser.get_current_page() + await page.goto(params.url) + await browser.wait_for_page_load() @self.registry.action('Go back', requires_browser=True) - def go_back(browser: Browser): - driver = browser._get_driver() - driver.back() - browser.wait_for_page_load() + async def go_back(browser: Browser): + page = await browser.get_current_page() + await page.go_back() + await browser.wait_for_page_load() # Element Interaction Actions @self.registry.action( 'Click element', param_model=ClickElementAction, requires_browser=True ) - def click_element(params: ClickElementAction, browser: Browser): - state = browser._cached_state + async def click_element(params: ClickElementAction, browser: Browser): + session = await browser.get_session() + state = session.cached_state if params.index not in state.selector_map: print(state.selector_map) @@ -68,14 +66,13 @@ class Controller: ) xpath = state.selector_map[params.index] - driver = browser._get_driver() - initial_handles = len(driver.window_handles) + initial_pages = len(session.context.pages) msg = None for _ in range(params.num_clicks): try: - browser._click_element_by_xpath(xpath) + await browser._click_element_by_xpath(xpath) msg = f'šŸ–±ļø Clicked element {params.index}: {xpath}' if params.num_clicks > 1: msg += f' ({_ + 1}/{params.num_clicks} clicks)' @@ -83,50 +80,36 @@ class Controller: logger.warning(f'Element no longer available after {_ + 1} clicks: {str(e)}') break - if len(driver.window_handles) > initial_handles: - browser.handle_new_tab() + if len(session.context.pages) > initial_pages: + await browser.switch_to_tab(-1) return ActionResult(extracted_content=f'Clicked element {msg}') @self.registry.action('Input text', param_model=InputTextAction, requires_browser=True) - def input_text(params: InputTextAction, browser: Browser): - state = browser._cached_state + async def input_text(params: InputTextAction, browser: Browser): + session = await browser.get_session() + state = session.cached_state + if params.index not in state.selector_map: raise Exception( f'Element index {params.index} does not exist - retry or use alternative actions' ) xpath = state.selector_map[params.index] - browser._input_text_by_xpath(xpath, params.text) + await browser._input_text_by_xpath(xpath, params.text) msg = f'āŒØļø Input text "{params.text}" into element {params.index}: {xpath}' return ActionResult(extracted_content=msg) # Tab Management Actions @self.registry.action('Switch tab', param_model=SwitchTabAction, requires_browser=True) - def switch_tab(params: SwitchTabAction, browser: Browser): - driver = browser._get_driver() - - # Verify handle exists - if params.handle not in driver.window_handles: - raise ValueError(f'Tab handle {params.handle} not found') - - # Only switch if we're not already on that tab - if params.handle != driver.current_window_handle: - driver.switch_to.window(params.handle) - browser._current_handle = params.handle - # Wait for tab to be ready - browser.wait_for_page_load() - - # Update and return tab info - tab_info = TabInfo(handle=params.handle, url=driver.current_url, title=driver.title) - browser._tab_cache[params.handle] = tab_info + async def switch_tab(params: SwitchTabAction, browser: Browser): + await browser.switch_to_tab(params.page_id) + # Wait for tab to be ready + await browser.wait_for_page_load() @self.registry.action('Open new tab', param_model=OpenTabAction, requires_browser=True) - def open_tab(params: OpenTabAction, browser: Browser): - driver = browser._get_driver() - driver.execute_script(f'window.open("{params.url}", "_blank");') - browser.wait_for_page_load() - browser.handle_new_tab() + async def open_tab(params: OpenTabAction, browser: Browser): + await browser.create_new_tab(params.url) # Content Actions @self.registry.action( @@ -134,42 +117,46 @@ class Controller: param_model=ExtractPageContentAction, requires_browser=True, ) - def extract_content(params: ExtractPageContentAction, browser: Browser): - driver = browser._get_driver() + async def extract_content(params: ExtractPageContentAction, browser: Browser): + page = await browser.get_current_page() content = MainContentExtractor.extract( # type: ignore - html=driver.page_source, + html=await page.content(), output_format=params.value, ) return ActionResult(extracted_content=content) @self.registry.action('Complete task', param_model=DoneAction, requires_browser=True) - def done(params: DoneAction, browser: Browser): - logger.info(f'āœ… Done on page {browser._cached_state.url}\n\n: {params.text}') + async def done(params: DoneAction, browser: Browser): + session = await browser.get_session() + state = session.cached_state + logger.info(f'āœ… Done on page {state.url}\n\n: {params.text}') return ActionResult(is_done=True, extracted_content=params.text) @self.registry.action( 'Scroll down the page by pixel amount - if no amount is specified, scroll down one page', - param_model=ScrollDownAction, + param_model=ScrollAction, requires_browser=True, ) - def scroll_down(params: ScrollDownAction, browser: Browser): - driver = browser._get_driver() + async def scroll_down(params: ScrollAction, browser: Browser): + page = await browser.get_current_page() if params.amount is not None: - driver.execute_script(f'window.scrollBy(0, {params.amount});') + await page.evaluate(f'window.scrollBy(0, {params.amount});') else: - body = driver.find_element(By.TAG_NAME, 'body') - body.send_keys(Keys.PAGE_DOWN) + await page.keyboard.press('PageDown') # scroll up @self.registry.action( 'Scroll up the page by pixel amount', - param_model=ScrollDownAction, + param_model=ScrollAction, requires_browser=True, ) - def scroll_up(params: ScrollDownAction, browser: Browser): - driver = browser._get_driver() - driver.execute_script(f'window.scrollBy(0, -{params.amount});') + async def scroll_up(params: ScrollAction, browser: Browser): + page = await browser.get_current_page() + if params.amount is not None: + await page.evaluate(f'window.scrollBy(0, -{params.amount});') + else: + await page.keyboard.press('PageUp') def action(self, description: str, **kwargs): """Decorator for registering custom actions @@ -179,12 +166,14 @@ class Controller: return self.registry.action(description, **kwargs) @time_execution_sync('--act') - def act(self, action: ActionModel) -> ActionResult: + async def act(self, action: ActionModel) -> ActionResult: """Execute an action""" try: for action_name, params in action.model_dump(exclude_unset=True).items(): if params is not None: - result = self.registry.execute_action(action_name, params, browser=self.browser) + result = await self.registry.execute_action( + action_name, params, browser=self.browser + ) if isinstance(result, str): return ActionResult(extracted_content=result) elif isinstance(result, ActionResult): diff --git a/browser_use/controller/views.py b/browser_use/controller/views.py index 342647480..02bbec176 100644 --- a/browser_use/controller/views.py +++ b/browser_use/controller/views.py @@ -27,7 +27,7 @@ class DoneAction(BaseModel): class SwitchTabAction(BaseModel): - handle: str + page_id: int class OpenTabAction(BaseModel): @@ -38,7 +38,5 @@ class ExtractPageContentAction(BaseModel): value: Literal['text', 'markdown', 'html'] = 'text' -class ScrollDownAction(BaseModel): - amount: Optional[int] = ( - None # The number of pixels to scroll down. If None, scroll down one page - ) +class ScrollAction(BaseModel): + amount: Optional[int] = None # The number of pixels to scroll. If None, scroll down/up one page diff --git a/browser_use/dom/service.py b/browser_use/dom/service.py index 71f93a442..aaf7d6b1c 100644 --- a/browser_use/dom/service.py +++ b/browser_use/dom/service.py @@ -3,36 +3,62 @@ import logging from typing import Optional from bs4 import BeautifulSoup, NavigableString, PageElement, Tag -from selenium import webdriver -from selenium.webdriver.remote.webelement import WebElement +from playwright.async_api import Page from browser_use.dom.views import ( BatchCheckResults, DomContentItem, ElementCheckResult, - ElementState, ProcessedDomContent, TextCheckResult, - TextState, ) -from browser_use.utils import time_execution_sync +from browser_use.utils import time_execution_async logger = logging.getLogger(__name__) class DomService: - def __init__(self, driver: webdriver.Chrome): - self.driver = driver - self.xpath_cache = {} # Add cache at instance level - - def get_clickable_elements(self) -> ProcessedDomContent: - # Clear xpath cache on each new DOM processing + def __init__(self, page: Page): + self.page = page self.xpath_cache = {} - html_content = self.driver.page_source - return self._process_content(html_content) - @time_execution_sync('--_process_content') - def _process_content(self, html_content: str) -> ProcessedDomContent: + async def get_clickable_elements(self) -> ProcessedDomContent: + self.xpath_cache = {} + html_content = await self._get_html_content() + return await self._process_content(html_content) + + async def _get_html_content(self, with_shadow_roots: bool = True) -> str: + """ + Get all DOM content including all shadow roots recursively. + + @param with_shadow_roots: If you want to include shadow roots in the content it's a bit slower but worth it in most cases. + """ + if with_shadow_roots: + full_content = await self.page.evaluate("""() => { + function getAllContent(root) { + let content = root.innerHTML || ''; + + // Get all elements with shadow roots + const elements = root.querySelectorAll('*'); + elements.forEach(element => { + if (element.shadowRoot) { + // Add a marker for shadow root start + content += ``; + content += getAllContent(element.shadowRoot); + content += ''; + } + }); + + return content; + } + + return `${getAllContent(document.body)}`; + }""") + return full_content + return await self.page.content() + + @time_execution_async('--_process_content') + async def _process_content(self, html_content: str) -> ProcessedDomContent: soup = BeautifulSoup(html_content, 'html.parser') output_items: list[DomContentItem] = [] @@ -86,8 +112,8 @@ class DomService: xpath_order_counter += 1 # Batch check all elements - element_results = self._batch_check_elements(interactive_elements) - text_results = self._batch_check_texts(text_nodes) + element_results = await self._batch_check_elements(interactive_elements) + text_results = await self._batch_check_texts(text_nodes) # Create ordered results ordered_results: list[ @@ -138,13 +164,14 @@ class DomService: return ProcessedDomContent(items=output_items, selector_map=selector_map) - def _batch_check_elements(self, elements: dict[str, tuple[Tag, int]]) -> BatchCheckResults: - """Batch check all interactive elements at once.""" + async def _batch_check_elements( + self, elements: dict[str, tuple[Tag, int]] + ) -> BatchCheckResults: if not elements: return BatchCheckResults(elements={}, texts={}) check_script = """ - return (function() { + (function() { const results = {}; const elements = %s; @@ -152,15 +179,13 @@ class DomService: const element = document.evaluate(xpath, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue; if (!element) continue; - // Check visibility - const isVisible = element.checkVisibility({ - checkOpacity: true, - checkVisibilityCSS: true - }); + const isVisible = element.offsetWidth > 0 && + element.offsetHeight > 0 && + window.getComputedStyle(element).visibility !== 'hidden' && + window.getComputedStyle(element).display !== 'none'; if (!isVisible) continue; - // Check if topmost const rect = element.getBoundingClientRect(); const points = [ {x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.25}, @@ -193,7 +218,7 @@ class DomService: """ % json.dumps({xpath: {} for xpath in elements.keys()}) try: - results = self.driver.execute_script(check_script) + results = await self.page.evaluate(check_script) return BatchCheckResults( elements={xpath: ElementCheckResult(**data) for xpath, data in results.items()}, texts={}, @@ -202,15 +227,14 @@ class DomService: logger.error('Error in batch element check: %s', e) return BatchCheckResults(elements={}, texts={}) - def _batch_check_texts( + async def _batch_check_texts( self, texts: dict[str, tuple[NavigableString, int]] ) -> BatchCheckResults: - """Batch check all text nodes at once.""" if not texts: return BatchCheckResults(elements={}, texts={}) check_script = """ - return (function() { + (function() { const results = {}; const texts = %s; @@ -256,7 +280,7 @@ class DomService: ) try: - results = self.driver.execute_script(check_script) + results = await self.page.evaluate(check_script) return BatchCheckResults( elements={}, texts={xpath: TextCheckResult(**data) for xpath, data in results.items()}, diff --git a/browser_use/dom/tests/extraction_test.py b/browser_use/dom/tests/extraction_test.py index 11cea675b..33c97120c 100644 --- a/browser_use/dom/tests/extraction_test.py +++ b/browser_use/dom/tests/extraction_test.py @@ -8,14 +8,14 @@ from browser_use.utils import time_execution_sync # @pytest.mark.skip("slow af") -def test_process_html_file(): +async def test_process_html_file(): browser = Browser(headless=False) - driver = browser._get_driver() + page = await browser.get_current_page() - dom_service = DomService(driver) + dom_service = DomService(page) - driver.get('https://kayak.com/flights') + await page.goto('https://kayak.com/flights') # browser.go_to_url('https://google.com/flights') # browser.go_to_url('https://immobilienscout24.de') @@ -25,7 +25,9 @@ def test_process_html_file(): # ) # browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]") - elements = time_execution_sync('get_clickable_elements')(dom_service.get_clickable_elements)() + elements = await time_execution_sync('get_clickable_elements')( + dom_service.get_clickable_elements + )() print(elements.dom_items_to_string(use_tabs=False)) print('Tokens:', count_string_tokens(elements.dom_items_to_string(), model='gpt-4o')) diff --git a/browser_use/logging_config.py b/browser_use/logging_config.py index 461703930..3e0d4f5b6 100644 --- a/browser_use/logging_config.py +++ b/browser_use/logging_config.py @@ -3,52 +3,121 @@ import os import sys +def addLoggingLevel(levelName, levelNum, methodName=None): + """ + Comprehensively adds a new logging level to the `logging` module and the + currently configured logging class. + + `levelName` becomes an attribute of the `logging` module with the value + `levelNum`. `methodName` becomes a convenience method for both `logging` + itself and the class returned by `logging.getLoggerClass()` (usually just + `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is + used. + + To avoid accidental clobberings of existing attributes, this method will + raise an `AttributeError` if the level name is already an attribute of the + `logging` module or if the method name is already present + + Example + ------- + >>> addLoggingLevel('TRACE', logging.DEBUG - 5) + >>> logging.getLogger(__name__).setLevel("TRACE") + >>> logging.getLogger(__name__).trace('that worked') + >>> logging.trace('so did this') + >>> logging.TRACE + 5 + + """ + if not methodName: + methodName = levelName.lower() + + if hasattr(logging, levelName): + raise AttributeError("{} already defined in logging module".format(levelName)) + if hasattr(logging, methodName): + raise AttributeError("{} already defined in logging module".format(methodName)) + if hasattr(logging.getLoggerClass(), methodName): + raise AttributeError("{} already defined in logger class".format(methodName)) + + # This method was inspired by the answers to Stack Overflow post + # http://stackoverflow.com/q/2183233/2988730, especially + # http://stackoverflow.com/a/13638084/2988730 + def logForLevel(self, message, *args, **kwargs): + if self.isEnabledFor(levelNum): + self._log(levelNum, message, args, **kwargs) + + def logToRoot(message, *args, **kwargs): + logging.log(levelNum, message, *args, **kwargs) + + logging.addLevelName(levelNum, levelName) + setattr(logging, levelName, levelNum) + setattr(logging.getLoggerClass(), methodName, logForLevel) + setattr(logging, methodName, logToRoot) + + def setup_logging(): - debug_logging = os.getenv('BROWSER_USE_DEBUG_LOGGING', 'false').lower() == 'true' + # Try to add RESULT level, but ignore if it already exists + try: + addLoggingLevel("RESULT", 35) # This allows ERROR, FATAL and CRITICAL + except AttributeError: + pass # Level already exists, which is fine - # Check if handlers are already set up - if logging.getLogger().hasHandlers(): - return + log_type = os.getenv("BROWSER_USE_LOGGING_LEVEL", "info").lower() - # Clear existing handlers - root = logging.getLogger() - root.handlers = [] + # Check if handlers are already set up + if logging.getLogger().hasHandlers(): + return - class BrowserUseFormatter(logging.Formatter): - def format(self, record): - if record.name.startswith('browser_use.'): - record.name = record.name.split('.')[-2] - return super().format(record) + # Clear existing handlers + root = logging.getLogger() + root.handlers = [] - # Setup single handler for all loggers - console = logging.StreamHandler(sys.stdout) - console.setFormatter(BrowserUseFormatter('%(levelname)-8s [%(name)s] %(message)s')) + class BrowserUseFormatter(logging.Formatter): + def format(self, record): + if record.name.startswith("browser_use."): + record.name = record.name.split(".")[-2] + return super().format(record) - # Configure root logger only - root.addHandler(console) + # Setup single handler for all loggers + console = logging.StreamHandler(sys.stdout) - if debug_logging: - root.setLevel(logging.DEBUG) - else: - root.setLevel(logging.INFO) + # adittional setLevel here to filter logs + if log_type == "result": + console.setLevel("RESULT") + console.setFormatter(BrowserUseFormatter("%(message)s")) + else: + console.setFormatter( + BrowserUseFormatter("%(levelname)-8s [%(name)s] %(message)s") + ) - # Configure browser_use logger to prevent propagation - browser_use_logger = logging.getLogger('browser_use') - browser_use_logger.propagate = False - browser_use_logger.addHandler(console) + # Configure root logger only + root.addHandler(console) - # Silence third-party loggers - for logger in [ - 'WDM', - 'httpx', - 'selenium', - 'urllib3', - 'asyncio', - 'langchain', - 'openai', - 'httpcore', - 'charset_normalizer', - ]: - third_party = logging.getLogger(logger) - third_party.setLevel(logging.ERROR) - third_party.propagate = False + # switch cases for log_type + if log_type == "result": + root.setLevel("RESULT") # string usage to avoid syntax error + elif log_type == "debug": + root.setLevel(logging.DEBUG) + else: + root.setLevel(logging.INFO) + + # Configure browser_use logger to prevent propagation + browser_use_logger = logging.getLogger("browser_use") + browser_use_logger.propagate = False + browser_use_logger.addHandler(console) + + # Silence third-party loggers + for logger in [ + "WDM", + "httpx", + "selenium", + "playwright", + "urllib3", + "asyncio", + "langchain", + "openai", + "httpcore", + "charset_normalizer", + ]: + third_party = logging.getLogger(logger) + third_party.setLevel(logging.ERROR) + third_party.propagate = False diff --git a/browser_use/telemetry/service.py b/browser_use/telemetry/service.py index bbb2b69fc..d63f59e2d 100644 --- a/browser_use/telemetry/service.py +++ b/browser_use/telemetry/service.py @@ -34,7 +34,7 @@ class ProductTelemetry: def __init__(self) -> None: telemetry_disabled = os.getenv('ANONYMIZED_TELEMETRY', 'true').lower() == 'false' - self.debug_logging = os.getenv('BROWSER_USE_DEBUG_LOGGING', 'false').lower() == 'true' + self.debug_logging = os.getenv('BROWSER_USE_LOGGING_LEVEL', 'info').lower() == 'debug' if telemetry_disabled: self._posthog_client = None diff --git a/browser_use/telemetry/views.py b/browser_use/telemetry/views.py index 6ce12c17e..1e69c1770 100644 --- a/browser_use/telemetry/views.py +++ b/browser_use/telemetry/views.py @@ -5,47 +5,47 @@ from typing import Any, Dict, Optional @dataclass class BaseTelemetryEvent(ABC): - @property - @abstractmethod - def name(self) -> str: - pass + @property + @abstractmethod + def name(self) -> str: + pass - @property - def properties(self) -> Dict[str, Any]: - return {k: v for k, v in asdict(self).items() if k != 'name'} + @property + def properties(self) -> Dict[str, Any]: + return {k: v for k, v in asdict(self).items() if k != "name"} @dataclass class RegisteredFunction: - name: str - params: dict[str, Any] + name: str + params: dict[str, Any] @dataclass class ControllerRegisteredFunctionsTelemetryEvent(BaseTelemetryEvent): - registered_functions: list[RegisteredFunction] - name: str = 'controller_registered_functions' + registered_functions: list[RegisteredFunction] + name: str = "controller_registered_functions" @dataclass class AgentRunTelemetryEvent(BaseTelemetryEvent): - agent_id: str - task: str - name: str = 'agent_run' + agent_id: str + task: str + name: str = "agent_run" @dataclass class AgentStepErrorTelemetryEvent(BaseTelemetryEvent): - agent_id: str - error: str - name: str = 'agent_step_error' + agent_id: str + error: str + name: str = "agent_step_error" @dataclass class AgentEndTelemetryEvent(BaseTelemetryEvent): - agent_id: str - task: str - steps: int - success: bool - error: Optional[str] = None - name: str = 'agent_end' + agent_id: str + task: str + steps: int + success: bool + error: Optional[str] = None + name: str = "agent_end" diff --git a/examples/check_appointment.py b/examples/check_appointment.py index 9de69f687..b182d3b1b 100644 --- a/examples/check_appointment.py +++ b/examples/check_appointment.py @@ -1,16 +1,13 @@ import asyncio -from typing import List, Optional import os +import dotenv from langchain_openai import ChatOpenAI - +from pydantic import BaseModel, SecretStr from browser_use.agent.service import Agent -from browser_use.browser.service import Browser from browser_use.controller.service import Controller -from pydantic import BaseModel -import dotenv dotenv.load_dotenv() @@ -18,28 +15,26 @@ controller = Controller() class WebpageInfo(BaseModel): - link: str = "https://appointment.mfa.gr/en/reservations/aero/ireland-grcon-dub/" + link: str = 'https://appointment.mfa.gr/en/reservations/aero/ireland-grcon-dub/' - -@controller.action("Go to the webpage", param_model=WebpageInfo) +@controller.action('Go to the webpage', param_model=WebpageInfo) def go_to_webpage(webpage_info: WebpageInfo): - return webpage_info.link - + return webpage_info.link async def main(): - task = ( - 'Go to the Greece MFA webpage via the link I provided you.' - 'Check the visa appointment dates. If there is no available date in this month, check the next month.' - 'If there is no available date in both months, tell me there is no available date.' - ) + task = ( + 'Go to the Greece MFA webpage via the link I provided you.' + 'Check the visa appointment dates. If there is no available date in this month, check the next month.' + 'If there is no available date in both months, tell me there is no available date.' + ) - model = ChatOpenAI(model='gpt-4o-mini', api_key=os.getenv('OPENAI_API_KEY')) - agent = Agent(task, model, controller, use_vision=True) - - result = await agent.run() + model = ChatOpenAI(model='gpt-4o-mini', api_key=SecretStr(os.getenv('OPENAI_API_KEY', ''))) + agent = Agent(task, model, controller, use_vision=True) + + result = await agent.run() if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/file_upload.py b/examples/file_upload.py index 55576ed63..103423958 100644 --- a/examples/file_upload.py +++ b/examples/file_upload.py @@ -7,9 +7,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import asyncio from langchain_openai import ChatOpenAI -from selenium.webdriver.common.action_chains import ActionChains -from selenium.webdriver.common.by import By -from selenium.webdriver.common.keys import Keys from browser_use.agent.service import Agent from browser_use.browser.service import Browser @@ -28,16 +25,20 @@ def ask_human(question: str) -> str: 'Upload file - the file name is inside the function - you only need to call this with the correct index', requires_browser=True, ) -def upload_file(index: int, browser: Browser): - element = browser.get_element(index) +async def upload_file(index: int, browser: Browser): + element = await browser.get_element_by_index(index) my_file = Path.cwd() / 'examples/test_cv.txt' - element.send_keys(str(my_file.absolute())) + if not element: + raise Exception(f'Element with index {index} not found') + + await element.set_input_files(str(my_file.absolute())) return f'Uploaded file to index {index}' @controller.action('Close file dialog', requires_browser=True) -def close_file_dialog(browser: Browser): - ActionChains(browser._get_driver()).send_keys(Keys.ESCAPE).perform() +async def close_file_dialog(browser: Browser): + page = await browser.get_current_page() + await page.keyboard.press('Escape') async def main(): diff --git a/examples/find_and_apply_to_jobs.py b/examples/find_and_apply_to_jobs.py index e5f2f265e..d67693eff 100644 --- a/examples/find_and_apply_to_jobs.py +++ b/examples/find_and_apply_to_jobs.py @@ -1,7 +1,14 @@ +""" +Find and apply to jobs. + +@dev You need to add OPENAI_API_KEY to your environment variables. + +Also you have to install PyPDF2: pip install PyPDF2 +""" + import csv import os import sys -import time from pathlib import Path from PyPDF2 import PdfReader @@ -13,9 +20,6 @@ from typing import List, Optional from langchain_openai import ChatOpenAI from pydantic import BaseModel -from selenium.webdriver.common.action_chains import ActionChains -from selenium.webdriver.common.by import By -from selenium.webdriver.common.keys import Keys from browser_use.agent.service import Agent from browser_use.browser.service import Browser @@ -65,17 +69,21 @@ def read_cv(): @controller.action('Upload cv to index', requires_browser=True) -def upload_cv(index: int, browser: Browser): - close_file_dialog(browser) - element = browser.get_element(index) +async def upload_cv(index: int, browser: Browser): + await close_file_dialog(browser) + element = await browser.get_element_by_index(index) my_cv = Path.cwd() / 'your_cv.pdf' - element.send_keys(str(my_cv.absolute())) + if not element: + raise Exception(f'Element with index {index} not found') + + await element.set_input_files(str(my_cv.absolute())) return f'Uploaded cv to index {index}' @controller.action('Close file dialog', requires_browser=True) -def close_file_dialog(browser: Browser): - ActionChains(browser._get_driver()).send_keys(Keys.ESCAPE).perform() +async def close_file_dialog(browser: Browser): + page = await browser.get_current_page() + await page.keyboard.press('Escape') async def main(): diff --git a/examples/who_starred_my_repo.py b/examples/who_starred_my_repo.py index dc4fcaf1d..385ae4bd3 100644 --- a/examples/who_starred_my_repo.py +++ b/examples/who_starred_my_repo.py @@ -48,8 +48,9 @@ class PageSaver(BaseModel): @controller.action('Save current page info', param_model=PageSaver, requires_browser=True) -def save_page_info(params: PageSaver, browser: Browser): - state = browser.get_state() +async def save_page_info(params: PageSaver, browser: Browser): + session = await browser.get_session() + state = session.cached_state with open(params.filename, 'w') as f: f.write(f'URL: {state.url}\n') f.write(f'Title: {state.title}\n') diff --git a/pyproject.toml b/pyproject.toml index f893b42d2..5bb1095af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description = "Make websites accessible for AI agents" authors = [ { name = "Gregor Zunic" } ] -version = "0.1.6" +version = "0.1.7" readme = "README.md" requires-python = ">=3.11" classifiers = [ @@ -14,7 +14,6 @@ classifiers = [ ] dependencies = [ "MainContentExtractor>=0.0.4", - "Selenium-Screenshot>=2.1.0", "beautifulsoup4>=4.12.3", "langchain>=0.3.7", "langchain-openai>=0.2.5", @@ -23,9 +22,9 @@ dependencies = [ "pydantic>=2.9.2", "python-dotenv>=1.0.1", "requests>=2.32.3", - "selenium>=4.26.1", "webdriver-manager>=4.0.2", - "posthog>=3.7.0" + "posthog>=3.7.0", + "playwright>=1.48.0" ] [project.optional-dependencies] diff --git a/tests/test_agent_actions.py b/tests/test_agent_actions.py index 699491938..79721ffff 100644 --- a/tests/test_agent_actions.py +++ b/tests/test_agent_actions.py @@ -1,7 +1,4 @@ -import asyncio - import pytest -from langchain_anthropic import ChatAnthropic from langchain_openai import ChatOpenAI from pydantic import BaseModel @@ -27,10 +24,10 @@ async def agent_with_controller(): yield controller finally: if controller.browser: - controller.browser.close(force=True) + await controller.browser.close(force=True) -@pytest.mark.asyncio +# @pytest.mark.asyncio async def test_ecommerce_interaction(llm, agent_with_controller): """Test complex ecommerce interaction sequence""" agent = Agent( @@ -73,7 +70,7 @@ async def test_ecommerce_interaction(llm, agent_with_controller): assert 'input_exact_correct' in action_sequence or 'correct_in_input' in action_sequence -@pytest.mark.asyncio +# @pytest.mark.asyncio async def test_error_recovery(llm, agent_with_controller): """Test agent's ability to recover from errors""" agent = Agent( @@ -98,7 +95,7 @@ async def test_error_recovery(llm, agent_with_controller): assert recovery_action is not None -@pytest.mark.asyncio +# @pytest.mark.asyncio async def test_find_contact_email(llm, agent_with_controller): """Test agent's ability to find contact email on a website""" agent = Agent( @@ -121,7 +118,7 @@ async def test_find_contact_email(llm, agent_with_controller): assert email_action is not None -@pytest.mark.asyncio +# @pytest.mark.asyncio async def test_agent_finds_installation_command(llm, agent_with_controller): """Test agent's ability to find the pip installation command for browser-use on the web""" agent = Agent( diff --git a/tests/test_core_functionality.py b/tests/test_core_functionality.py index 903af072b..e2bfb10bb 100644 --- a/tests/test_core_functionality.py +++ b/tests/test_core_functionality.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from langchain_openai import ChatOpenAI @@ -16,12 +14,12 @@ def llm(): @pytest.fixture async def controller(): """Initialize the controller""" - controller = Controller() + controller = Controller(keep_open=False) try: yield controller finally: if controller.browser: - controller.browser.close(force=True) + await controller.browser.close(force=True) @pytest.mark.asyncio @@ -168,15 +166,15 @@ async def test_scroll_down(llm, controller): ) # Get the browser instance browser = controller.browser - driver = browser._get_driver() + page = await browser.get_current_page() # Navigate to the page and get initial scroll position await agent.run(max_steps=1) - initial_scroll_position = driver.execute_script('return window.pageYOffset;') + initial_scroll_position = await page.evaluate('window.scrollY;') # Perform the scroll down action await agent.run(max_steps=2) - final_scroll_position = driver.execute_script('return window.pageYOffset;') + final_scroll_position = await page.evaluate('window.scrollY;') # Validate that the scroll position has changed assert final_scroll_position > initial_scroll_position, 'Page did not scroll down' diff --git a/tests/test_mind2web.py b/tests/test_mind2web.py index 5603b8fa9..01b38eec1 100644 --- a/tests/test_mind2web.py +++ b/tests/test_mind2web.py @@ -3,7 +3,6 @@ Test browser automation using Mind2Web dataset tasks with pytest framework. """ import json -import logging import os from typing import Any, Dict, List @@ -47,7 +46,7 @@ async def controller(): yield controller finally: if controller.browser: - controller.browser.close(force=True) + await controller.browser.close(force=True) # run with: pytest -s -v tests/test_mind2web.py:test_random_samples diff --git a/tests/test_self_registered_actions.py b/tests/test_self_registered_actions.py index 23341dfa0..4ab7b1be7 100644 --- a/tests/test_self_registered_actions.py +++ b/tests/test_self_registered_actions.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from langchain_openai import ChatOpenAI from pydantic import BaseModel @@ -67,9 +65,10 @@ async def controller(): yield controller finally: if controller.browser: - controller.browser.close(force=True) + await controller.browser.close(force=True) +# @pytest.mark.skip(reason="Skipping test for now") @pytest.mark.asyncio async def test_self_registered_actions_no_pydantic(llm, controller): """Test self-registered actions with individual arguments""" @@ -88,13 +87,15 @@ async def test_self_registered_actions_no_pydantic(llm, controller): assert 'concatenate_strings' in action_names +# @pytest.mark.skip(reason="Skipping test for now") @pytest.mark.asyncio async def test_mixed_arguments_actions(llm, controller): """Test actions with mixed argument types""" # Define another action during the test + # Test for async actions @controller.action('Calculate the area of a rectangle') - def calculate_area(length: float, width: float): + async def calculate_area(length: float, width: float): area = length * width return f'The area is {area}' diff --git a/tests/test_stress.py b/tests/test_stress.py index d55d1fc0e..142e7566e 100644 --- a/tests/test_stress.py +++ b/tests/test_stress.py @@ -1,4 +1,3 @@ -import asyncio import time import pytest @@ -22,7 +21,7 @@ async def controller(): yield controller finally: if controller.browser: - controller.browser.close(force=True) + await controller.browser.close(force=True) # should get rate limited @@ -30,7 +29,7 @@ async def controller(): async def test_open_10_tabs_and_extract_content(llm, controller): """Stress test: Open 10 tabs and extract content""" agent = Agent( - task='Open new tabs with example.com, example.net, example.org, and seven more example sites. Then, extract the content from each.', + task='Open new tabs with example.com, example.net, example.org. Then, extract the content from each.', llm=llm, controller=controller, ) @@ -45,4 +44,4 @@ async def test_open_10_tabs_and_extract_content(llm, controller): errors = [h.result.error for h in history if h.result and h.result.error] assert len(errors) == 0, 'Errors occurred during the test' # check if 10 tabs were opened - assert len(controller.browser.current_state.tabs) >= 10, '10 tabs were not opened' + assert len(controller.browser.current_state.tabs) >= 3, '3 tabs were not opened'