From 8a10f7ee99802190b91fb4e6623bb01cabfb0067 Mon Sep 17 00:00:00 2001 From: Alexis Deprez Date: Fri, 4 Oct 2024 13:45:02 +0200 Subject: [PATCH] Simplify base driver (#612) * simplify driver implementation * feat: move actions execution to node impl * chore: format * fix: allow element xpath selection without filter * fix: allow 3XX http codes * feat: allow run customization though SDK --- .../lavague/drivers/playwright/base.py | 78 +- .../lavague/drivers/selenium/__init__.py | 1 - .../lavague/drivers/selenium/base.py | 965 ++++-------------- .../lavague/drivers/selenium/node.py | 159 +++ .../lavague/drivers/selenium/prompt.py | 139 +++ lavague-sdk/lavague/sdk/agent.py | 17 +- lavague-sdk/lavague/sdk/base_driver.py | 747 -------------- .../lavague/sdk/base_driver/__init__.py | 1 + lavague-sdk/lavague/sdk/base_driver/base.py | 294 ++++++ .../lavague/sdk/base_driver/interaction.py | 59 ++ .../lavague/sdk/base_driver/javascript.py | 319 ++++++ lavague-sdk/lavague/sdk/base_driver/node.py | 69 ++ lavague-sdk/lavague/sdk/client.py | 43 +- lavague-sdk/lavague/sdk/exceptions.py | 10 + lavague-sdk/lavague/sdk/trajectory/base.py | 7 +- lavague-sdk/lavague/sdk/trajectory/model.py | 1 + .../lavague/sdk/utilities/format_utils.py | 3 +- .../lavague/sdk/utilities/version_checker.py | 2 + 18 files changed, 1322 insertions(+), 1592 deletions(-) create mode 100644 lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py create mode 100644 lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/prompt.py delete mode 100644 lavague-sdk/lavague/sdk/base_driver.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/__init__.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/base.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/interaction.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/javascript.py create mode 100644 lavague-sdk/lavague/sdk/base_driver/node.py diff --git a/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py b/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py index 4b9cad8c..4c5bd9ba 100644 --- a/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py +++ b/lavague-integrations/drivers/lavague-drivers-playwright/lavague/drivers/playwright/base.py @@ -1,15 +1,16 @@ -from io import BytesIO import json import os -from PIL import Image from typing import Callable, Optional, Any, Mapping, Dict, List from playwright.sync_api import Page, Locator -from lavague.sdk.base_driver import ( - BaseDriver, +from lavague.sdk.base_driver import BaseDriver +from lavague.sdk.base_driver.interaction import ( + InteractionType, + PossibleInteractionsByXpath, +) + +from lavague.sdk.base_driver.javascript import ( JS_GET_INTERACTIVES, JS_WAIT_DOM_IDLE, - PossibleInteractionsByXpath, - InteractionType, ) import time @@ -42,7 +43,7 @@ def __init__( # Before modifying this function, check if your changes are compatible with code_for_init which parses this code # these imports are necessary as they will be pasted to the output def default_init_code(self) -> Page: - from lavague.sdk.base_driver import JS_SETUP_GET_EVENTS + from lavague.sdk.base_driver.javascript import JS_SETUP_GET_EVENTS try: from playwright.sync_api import sync_playwright @@ -114,72 +115,9 @@ def get_html(self) -> str: def destroy(self) -> None: self.page.close() - def check_visibility(self, xpath: str) -> bool: - try: - locator = self.page.locator(f"xpath={xpath}") - return locator.is_visible() and locator.is_enabled() - except: - return False - def resolve_xpath(self, xpath) -> Locator: return self.page.locator(f"xpath={xpath}") - def get_highlighted_element(self, generated_code: str): - elements = [] - - data = json.loads(generated_code) - if not isinstance(data, List): - data = [data] - for item in data: - action_name = item["action"]["name"] - if action_name != "fail": - xpath = item["action"]["args"]["xpath"] - try: - elem = self.page.locator(f"xpath={xpath}") - elements.append(elem) - except: - pass - - if len(elements) == 0: - raise ValueError("No element found.") - - outputs = [] - for element in elements: - element: Locator - - bounding_box = {} - viewport_size = {} - - self.execute_script( - "arguments[0].setAttribute('style', arguments[1]);", - element, - "border: 2px solid red;", - ) - self.execute_script( - "arguments[0].scrollIntoView({block: 'center'});", element - ) - screenshot = self.get_screenshot_as_png() - - bounding_box["x1"] = element.bounding_box()["x"] - bounding_box["y1"] = element.bounding_box()["y"] - bounding_box["x2"] = bounding_box["x1"] + element.bounding_box()["width"] - bounding_box["y2"] = bounding_box["y1"] + element.bounding_box()["height"] - - viewport_size["width"] = self.execute_script("return window.innerWidth;") - viewport_size["height"] = self.execute_script("return window.innerHeight;") - screenshot = BytesIO(screenshot) - screenshot = Image.open(screenshot) - output = { - "screenshot": screenshot, - "bounding_box": bounding_box, - "viewport_size": viewport_size, - } - outputs.append(output) - return outputs - - def maximize_window(self) -> None: - pass - def exec_code( self, code: str, diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py index 1c340f2e..a8e41bbe 100644 --- a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/__init__.py @@ -1,2 +1 @@ from lavague.drivers.selenium.base import SeleniumDriver -from lavague.drivers.selenium.base import BrowserbaseRemoteConnection diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py index 9cd741a2..35a1c170 100644 --- a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/base.py @@ -1,376 +1,206 @@ -import re -from typing import Any, Optional, Callable, Mapping, Dict, List -from selenium.webdriver.remote.webdriver import WebDriver -from selenium.webdriver.remote.shadowroot import ShadowRoot -from selenium.webdriver.common.by import By -from selenium.webdriver.common.keys import Keys -from selenium.common.exceptions import ( - NoSuchElementException, - WebDriverException, - ElementClickInterceptedException, - TimeoutException, +import json +import time +from typing import Callable, Dict, List, Optional + +from lavague.drivers.selenium.node import SeleniumNode +from lavague.drivers.selenium.prompt import SELENIUM_PROMPT_TEMPLATE +from lavague.sdk.base_driver import BaseDriver +from lavague.sdk.base_driver.interaction import ( + InteractionType, + PossibleInteractionsByXpath, + ScrollDirection, ) -from selenium.webdriver.support.ui import Select, WebDriverWait -from selenium.webdriver.remote.webelement import WebElement -from selenium.webdriver.common.actions.wheel_input import ScrollOrigin -from lavague.sdk.utilities.format_utils import quote_numeric_yaml_values -from lavague.sdk.base_driver import ( - BaseDriver, +from lavague.sdk.base_driver.javascript import ( + ATTACH_MOVE_LISTENER, JS_GET_INTERACTIVES, - JS_WAIT_DOM_IDLE, JS_GET_SCROLLABLE_PARENT, JS_GET_SHADOW_ROOTS, - PossibleInteractionsByXpath, - ScrollDirection, - InteractionType, - DOMNode, + JS_SETUP_GET_EVENTS, + JS_WAIT_DOM_IDLE, + REMOVE_HIGHLIGHT, + get_highlighter_style, +) +from lavague.sdk.exceptions import ( + CannotBackException, + NoPageException, +) + +from selenium.common.exceptions import ( + NoSuchElementException, + TimeoutException, ) -from lavague.sdk.exceptions import CannotBackException -from PIL import Image -from io import BytesIO +from selenium.webdriver import Chrome from selenium.webdriver.chrome.options import Options from selenium.webdriver.common.action_chains import ActionChains -import time -import yaml -import json -from selenium.webdriver.remote.remote_connection import RemoteConnection -import requests -import os -from lavague.drivers.selenium.javascript import ( - ATTACH_MOVE_LISTENER, - get_highlighter_style, - REMOVE_HIGHLIGHT, -) +from selenium.webdriver.common.actions.wheel_input import ScrollOrigin +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.remote.webelement import WebElement +from selenium.webdriver.support.ui import WebDriverWait -class SeleniumDriver(BaseDriver): +class SeleniumDriver(BaseDriver[SeleniumNode]): driver: WebDriver - last_hover_xpath: Optional[str] = None def __init__( self, - url: Optional[str] = None, - get_selenium_driver: Optional[Callable[[], WebDriver]] = None, + options: Optional[Options] = None, headless: bool = True, user_data_dir: Optional[str] = None, - width: Optional[int] = 1096, - height: Optional[int] = 1096, - options: Optional[Options] = None, - driver: Optional[WebDriver] = None, - log_waiting_time=False, waiting_completion_timeout=10, - remote_connection: Optional["BrowserbaseRemoteConnection"] = None, - ): - self.headless = headless - self.user_data_dir = user_data_dir - self.width = width - self.height = height - self.options = options - self.driver = driver - self.log_waiting_time = log_waiting_time + log_waiting_time=False, + user_agent="Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36", + auto_init=True, + ) -> None: self.waiting_completion_timeout = waiting_completion_timeout - self.remote_connection = remote_connection - super().__init__(url, get_selenium_driver) - - # Default code to init the driver. - # Before making any change to this, make sure it is compatible with code_for_init, which parses the code of this function - # These imports are necessary as they will be pasted to the output - def default_init_code(self) -> Any: - from selenium import webdriver - from selenium.webdriver.common.by import By - from selenium.webdriver.chrome.options import Options - from selenium.webdriver.common.keys import Keys - from selenium.webdriver.common.action_chains import ActionChains - from lavague.sdk.base_driver import JS_SETUP_GET_EVENTS - - if self.options: - chrome_options = self.options + self.log_waiting_time = log_waiting_time + if options: + self.options = options else: - chrome_options = Options() - if self.headless: - chrome_options.add_argument("--headless=new") - if self.user_data_dir: - chrome_options.add_argument(f"--user-data-dir={self.user_data_dir}") - else: - chrome_options.add_argument("--lang=en") - user_agent = "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36" - chrome_options.add_argument(f"user-agent={user_agent}") - chrome_options.add_argument("--no-sandbox") - chrome_options.page_load_strategy = "normal" - # allow access to cross origin iframes - chrome_options.add_argument("--disable-web-security") - chrome_options.add_argument("--disable-site-isolation-trials") - chrome_options.add_argument("--disable-notifications") - chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"}) - - if self.remote_connection: - chrome_options.add_experimental_option("debuggerAddress", "localhost:9223") - self.driver = webdriver.Remote( - self.remote_connection, options=chrome_options - ) - elif self.driver is None: - self.driver = webdriver.Chrome(options=chrome_options) - - # 538: browserbase implementation - move execute_cdp_cmd to inner block to avoid error - # AttributeError: 'WebDriver' object has no attribute 'execute_cdp_cmd' - self.driver.execute_cdp_cmd( - "Page.addScriptToEvaluateOnNewDocument", - {"source": JS_SETUP_GET_EVENTS}, - ) - self.resize_driver(self.width, self.height) - return self.driver - - def __enter__(self): - return self - - def __exit__(self, *args): - self.destroy() + self.options = Options() + if headless: + self.options.add_argument("--headless=new") + self.options.add_argument("--lang=en") + self.options.add_argument(f"user-agent={user_agent}") + self.options.add_argument("--disable-notifications") + if user_data_dir: + self.options.add_argument(f"--user-data-dir={user_data_dir}") + self.options.page_load_strategy = "normal" + self.options.add_argument("--no-sandbox") + self.options.add_argument("--disable-web-security") + self.options.add_argument("--disable-site-isolation-trials") + self.options.set_capability("goog:loggingPrefs", {"performance": "ALL"}) + if auto_init: + self.init() + + def init(self) -> None: + self.driver = Chrome(options=self.options) + self.driver.execute_cdp_cmd( + "Page.addScriptToEvaluateOnNewDocument", + {"source": JS_SETUP_GET_EVENTS}, + ) - def get_driver(self) -> WebDriver: - return self.driver + def destroy(self) -> None: + """Cleanly destroy the underlying driver""" + self.driver.quit() - def resize_driver(self, width, height) -> None: - if width is None and height is None: - return None - # Selenium is only being able to set window size and not viewport size + def resize_driver(self, width: int, height: int): + """Resize the viewport to a targeted height and width""" self.driver.set_window_size(width, height) viewport_height = self.driver.execute_script("return window.innerHeight;") - height_difference = height - viewport_height self.driver.set_window_size(width, height + height_difference) - self.width = width - self.height = height - - def code_for_resize(self, width, height) -> str: - return f""" -driver.set_window_size({width}, {height}) -viewport_height = driver.execute_script("return window.innerHeight;") -height_difference = {height} - viewport_height -driver.set_window_size({width}, {height} + height_difference) -""" - - def get_url(self) -> Optional[str]: + + def get_url(self) -> str: + """Get the url of the current page, raise NoPageException if no page is loaded""" if self.driver.current_url == "data:,": - return None + raise NoPageException() return self.driver.current_url - def code_for_get(self, url: str) -> str: - return f'driver.get("{url}")' - def get(self, url: str) -> None: + """Navigate to the url""" self.driver.get(url) def back(self) -> None: - if self.driver.execute_script("return !document.referrer"): + """Navigate back, raise CannotBackException if history root is reached""" + if self.driver.execute_script("return !document.referrer;"): raise CannotBackException() self.driver.back() - def code_for_back(self) -> None: - return "driver.back()" - def get_html(self) -> str: + """ + Returns the HTML of the current page. + If clean is True, We remove unnecessary tags and attributes from the HTML. + Clean HTMLs are easier to process for the LLM. + """ return self.driver.page_source - def get_screenshot_as_png(self) -> bytes: - return self.driver.get_screenshot_as_png() - - def destroy(self) -> None: - self.driver.quit() - - def maximize_window(self) -> None: - self.driver.maximize_window() + def get_tabs(self) -> str: + """Return description of the tabs opened with the current tab being focused. - def check_visibility(self, xpath: str) -> bool: - try: - # Done manually here to avoid issues - element = self.resolve_xpath(xpath).element - res = ( - element is not None and element.is_displayed() and element.is_enabled() - ) - self.switch_default_frame() - return res - except: - return False + Example of output: + Tabs opened: + 0 - Overview - OpenAI API + 1 - [CURRENT] Nos destinations Train - SNCF Connect + """ + window_handles = self.driver.window_handles + # Store the current window handle (focused tab) + current_handle = self.driver.current_window_handle + tab_info = [] + tab_id = 0 - def get_viewport_size(self) -> dict: - viewport_size = {} - viewport_size["width"] = self.execute_script("return window.innerWidth;") - viewport_size["height"] = self.execute_script("return window.innerHeight;") - return viewport_size + for handle in window_handles: + # Switch to each tab + self.driver.switch_to.window(handle) - def get_highlighted_element(self, generated_code: str): - elements = [] - - # Ensures that numeric values are quoted - generated_code = quote_numeric_yaml_values(generated_code) - - data = yaml.safe_load(generated_code) - if not isinstance(data, List): - data = [data] - for item in data: - for action in item["actions"]: - try: - xpath = action["action"]["args"]["xpath"] - elem = self.driver.find_element(By.XPATH, xpath) - elements.append(elem) - except: - pass - - outputs = [] - for element in elements: - element: WebElement - - bounding_box = {} - - self.execute_script( - "arguments[0].setAttribute('style', arguments[1]);", - element, - "border: 2px solid red;", - ) - self.execute_script( - "arguments[0].scrollIntoView({block: 'center'});", element - ) - screenshot = self.get_screenshot_as_png() - - bounding_box["x1"] = element.location["x"] - bounding_box["y1"] = element.location["y"] - bounding_box["x2"] = bounding_box["x1"] + element.size["width"] - bounding_box["y2"] = bounding_box["y1"] + element.size["height"] - - screenshot = BytesIO(screenshot) - screenshot = Image.open(screenshot) - output = { - "screenshot": screenshot, - "bounding_box": bounding_box, - "viewport_size": self.get_viewport_size(), - } - outputs.append(output) - return outputs - - def switch_frame(self, xpath): - iframe = self.driver.find_element(By.XPATH, xpath) - self.driver.switch_to.frame(iframe) + # Get the title of the current tab + title = self.driver.title - def switch_default_frame(self) -> None: - self.driver.switch_to.default_content() + # Check if this is the focused tab + if handle == current_handle: + tab_info.append(f"{tab_id} - [CURRENT] {title}") + else: + tab_info.append(f"{tab_id} - {title}") - def switch_parent_frame(self) -> None: - self.driver.switch_to.parent_frame() + tab_id += 1 - def resolve_xpath(self, xpath: Optional[str]) -> "SeleniumNode": - return SeleniumNode(xpath, self) + # Switch back to the original tab + self.driver.switch_to.window(current_handle) - def exec_code( - self, - code: str, - globals: dict[str, Any] = None, - locals: Mapping[str, object] = None, - ): - # Ensures that numeric values are quoted to avoid issues with YAML parsing - code = quote_numeric_yaml_values(code) - - data = yaml.safe_load(code) - if not isinstance(data, List): - data = [data] - for item in data: - for action in item["actions"]: - action_name = action["action"]["name"] - args = action["action"]["args"] - xpath = args.get("xpath", None) - - match action_name: - case "click": - self.click(xpath) - case "setValue": - self.set_value(xpath, args["value"]) - case "setValueAndEnter": - self.set_value(xpath, args["value"], True) - case "dropdownSelect": - self.dropdown_select(xpath, args["value"]) - case "hover": - self.hover(xpath) - case "scroll": - self.scroll( - xpath, - ScrollDirection.from_string(args.get("value", "DOWN")), - ) - case _: - raise ValueError(f"Unknown action: {action_name}") - - self.wait_for_idle() - - def execute_script(self, js_code: str, *args) -> Any: - return self.driver.execute_script(js_code, *args) - - def scroll_up(self): - self.scroll(direction=ScrollDirection.UP) - - def scroll_down(self): - self.scroll(direction=ScrollDirection.DOWN) - - def code_for_execute_script(self, js_code: str, *args) -> str: - return ( - f"driver.execute_script({js_code}, {', '.join(str(arg) for arg in args)})" - ) + tab_info = "\n".join(tab_info) + tab_info = "Tabs opened:\n" + tab_info + return tab_info - def hover(self, xpath: str): - with self.resolve_xpath(xpath) as element_resolved: - self.last_hover_xpath = xpath - ActionChains(self.driver).move_to_element( - element_resolved.element - ).perform() + def switch_tab(self, tab_id: int) -> None: + """Switch to the tab with the given id""" + window_handles = self.driver.window_handles + self.driver.switch_to.window(window_handles[tab_id]) - def scroll_page(self, direction: ScrollDirection = ScrollDirection.DOWN): - self.driver.execute_script(direction.get_page_script()) + def resolve_xpath(self, xpath: str): + """ + Return the element for the corresponding xpath, the underlying driver may switch iframe if necessary + """ + return SeleniumNode(self.driver, xpath) - def get_scroll_anchor(self, xpath_anchor: Optional[str] = None) -> WebElement: - with self.resolve_xpath( - xpath_anchor or self.last_hover_xpath - ) as element_resolved: - element = element_resolved.element - parent = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, element) - scroll_anchor = parent or element - return scroll_anchor + def get_viewport_size(self) -> dict: + """Return viewport size as {"width": int, "height": int}""" + viewport_size = {} + viewport_size["width"] = self.driver.execute_script("return window.innerWidth;") + viewport_size["height"] = self.driver.execute_script( + "return window.innerHeight;" + ) + return viewport_size - def get_scroll_container_size(self, scroll_anchor: WebElement): - container = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, scroll_anchor) - if container: - return ( - self.driver.execute_script( - "const r = arguments[0].getBoundingClientRect(); return [r.width, r.height]", - scroll_anchor, - ), - True, - ) - return ( - self.driver.execute_script( - "return [window.innerWidth, window.innerHeight]", - ), + def get_possible_interactions( + self, + in_viewport=True, + foreground_only=True, + types: List[InteractionType] = [ + InteractionType.CLICK, + InteractionType.TYPE, + InteractionType.HOVER, + ], + ) -> PossibleInteractionsByXpath: + """Get elements that can be interacted with as a dictionary mapped by xpath""" + exe: Dict[str, List[str]] = self.driver.execute_script( + JS_GET_INTERACTIVES, + in_viewport, + foreground_only, False, + [t.name for t in types], ) + res = dict() + for k, v in exe.items(): + res[k] = set(InteractionType[i] for i in v) + return res - def is_bottom_of_page(self) -> bool: - return not self.can_scroll(direction=ScrollDirection.DOWN) - - def can_scroll( - self, - xpath_anchor: Optional[str] = None, - direction: ScrollDirection = ScrollDirection.DOWN, - ) -> bool: - try: - scroll_anchor = self.get_scroll_anchor(xpath_anchor) - if scroll_anchor: - return self.driver.execute_script( - direction.get_script_element_is_scrollable(), - scroll_anchor, - ) - except NoSuchElementException: - pass - return self.driver.execute_script(direction.get_script_page_is_scrollable()) + def scroll_into_view(self, xpath: str): + with self.resolve_xpath(xpath) as node: + self.driver.execute_script("arguments[0].scrollIntoView()", node.element) def scroll( self, - xpath_anchor: Optional[str] = None, + xpath_anchor: Optional[str] = "/html/body", direction: ScrollDirection = ScrollDirection.DOWN, scroll_factor=0.75, ): @@ -391,91 +221,38 @@ def scroll( ActionChains(self.driver).scroll_by_amount( scroll_xy[0], scroll_xy[1] ).perform() - if xpath_anchor: - self.last_hover_xpath = xpath_anchor except NoSuchElementException: self.scroll_page(direction) - def click(self, xpath: str): - with self.resolve_xpath(xpath) as element_resolved: - element = element_resolved.element - self.last_hover_xpath = xpath - try: - element.click() - except ElementClickInterceptedException: - try: - # Move to the element and click at its position - ActionChains(self.driver).move_to_element(element).click().perform() - except WebDriverException as click_error: - raise Exception( - f"Failed to click at element coordinates of {xpath} : {str(click_error)}" - ) - except Exception as e: - import traceback - - traceback.print_exc() - raise Exception( - f"An unexpected error occurred when trying to click on {xpath}: {str(e)}" - ) - - def set_value(self, xpath: str, value: str, enter: bool = False): - with self.resolve_xpath(xpath) as element_resolved: - elem = element_resolved.element - try: - self.last_hover_xpath = xpath - if elem.tag_name == "select": - # use the dropdown_select to set the value of a select - return self.dropdown_select(xpath, value) - if elem.tag_name == "input" and elem.get_attribute("type") == "file": - # set the value of a file input - return self.upload_file(xpath, value) - - elem.clear() - except: - # might not be a clearable element, but global click + send keys can still success - pass - - self.click(xpath) - - ( - ActionChains(self.driver) - .key_down(Keys.CONTROL) - .send_keys("a") - .key_up(Keys.CONTROL) - .send_keys(Keys.DELETE) # clear the input field - .send_keys(value) - .perform() - ) - if enter: - ActionChains(self.driver).send_keys(Keys.ENTER).perform() + def scroll_page(self, direction: ScrollDirection = ScrollDirection.DOWN): + self.driver.execute_script(direction.get_page_script()) - def dropdown_select(self, xpath: str, value: str): - with self.resolve_xpath(xpath) as element_resolved: - element = element_resolved.element - self.last_hover_xpath = xpath - - if element.tag_name != "select": - print( - f"Cannot use dropdown_select on {element.tag_name}, falling back to simple click on {xpath}" - ) - return self.click(xpath) - - select = Select(element) - try: - select.select_by_value(value) - except NoSuchElementException: - select.select_by_visible_text(value) - - def upload_file(self, xpath: str, file_path: str): - with self.resolve_xpath(xpath) as element_resolved: + def get_scroll_anchor(self, xpath_anchor: Optional[str] = None) -> WebElement: + with self.resolve_xpath(xpath_anchor or "/html/body") as element_resolved: element = element_resolved.element - self.last_hover_xpath = xpath - element.send_keys(file_path) + parent = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, element) + scroll_anchor = parent or element + return scroll_anchor - def perform_wait(self, duration: float): - import time + def get_scroll_container_size(self, scroll_anchor: WebElement): + container = self.driver.execute_script(JS_GET_SCROLLABLE_PARENT, scroll_anchor) + if container: + return ( + self.driver.execute_script( + "const r = arguments[0].getBoundingClientRect(); return [r.width, r.height]", + scroll_anchor, + ), + True, + ) + return ( + self.driver.execute_script( + "return [window.innerWidth, window.innerHeight]", + ), + False, + ) - time.sleep(duration) + def wait_for_dom_stable(self, timeout: float = 10): + self.driver.execute_script(JS_WAIT_DOM_IDLE, max(0, round(timeout * 1000))) def is_idle(self): active = 0 @@ -503,9 +280,6 @@ def is_idle(self): return len(request_ids) == 0 and active <= 0 - def wait_for_dom_stable(self, timeout: float = 10): - self.driver.execute_script(JS_WAIT_DOM_IDLE, max(0, round(timeout * 1000))) - def wait_for_idle(self): t = time.time() elapsed = 0 @@ -525,49 +299,39 @@ def wait_for_idle(self): ) def get_capability(self) -> str: + """Prompt to explain the llm which style of code he should output and which variables and imports he should expect""" return SELENIUM_PROMPT_TEMPLATE - def get_tabs(self): - driver = self.driver - window_handles = driver.window_handles - # Store the current window handle (focused tab) - current_handle = driver.current_window_handle - tab_info = [] - tab_id = 0 - - for handle in window_handles: - # Switch to each tab - driver.switch_to.window(handle) - - # Get the title of the current tab - title = driver.title - - # Check if this is the focused tab - if handle == current_handle: - tab_info.append(f"{tab_id} - [CURRENT] {title}") - else: - tab_info.append(f"{tab_id} - {title}") - - tab_id += 1 - - # Switch back to the original tab - driver.switch_to.window(current_handle) + def get_screenshot_as_png(self) -> bytes: + return self.driver.get_screenshot_as_png() - tab_info = "\n".join(tab_info) - tab_info = "Tabs opened:\n" + tab_info - return tab_info + def get_shadow_roots(self) -> Dict[str, str]: + """Return a dictionary of shadow roots HTML by xpath""" + return self.driver.execute_script(JS_GET_SHADOW_ROOTS) - def switch_tab(self, tab_id: int): - driver = self.driver - window_handles = driver.window_handles + def get_nodes(self, xpaths: List[str]) -> List[SeleniumNode]: + return [SeleniumNode(self.driver, xpath) for xpath in xpaths] - # Switch to the tab with the given id - driver.switch_to.window(window_handles[tab_id]) + def highlight_nodes( + self, xpaths: List[str], color: str = "red", label=False + ) -> Callable: + nodes = self.get_nodes(xpaths) + self.driver.execute_script(ATTACH_MOVE_LISTENER) + set_style = get_highlighter_style(color, label) + self.exec_script_for_nodes( + nodes, "arguments[0].forEach((a, i) => { " + set_style + "})" + ) + return self._add_highlighted_destructors( + lambda: self.remove_nodes_highlight(xpaths) + ) - def get_nodes(self, xpaths: List[str]) -> List["SeleniumNode"]: - return [SeleniumNode(xpath, self) for xpath in xpaths] + def remove_nodes_highlight(self, xpaths: List[str]): + self.exec_script_for_nodes( + self.get_nodes(xpaths), + REMOVE_HIGHLIGHT, + ) - def exec_script_for_nodes(self, nodes: List["SeleniumNode"], script: str): + def exec_script_for_nodes(self, nodes: List[SeleniumNode], script: str): standard_nodes: List[SeleniumNode] = [] special_nodes: List[SeleniumNode] = [] @@ -595,328 +359,11 @@ def exec_script_for_nodes(self, nodes: List["SeleniumNode"], script: str): script, [n.element], ) - self.switch_default_frame() - - def remove_nodes_highlight(self, xpaths: List[str]): - self.exec_script_for_nodes( - self.get_nodes(xpaths), - REMOVE_HIGHLIGHT, - ) - - def highlight_nodes( - self, xpaths: List[str], color: str = "red", label=False - ) -> Callable: - nodes = self.get_nodes(xpaths) - self.driver.execute_script(ATTACH_MOVE_LISTENER) - set_style = get_highlighter_style(color, label) - self.exec_script_for_nodes( - nodes, "arguments[0].forEach((a, i) => { " + set_style + "})" - ) - return self._add_highlighted_destructors( - lambda: self.remove_nodes_highlight(xpaths) - ) + self.driver.switch_to.default_content() - def get_possible_interactions( - self, - in_viewport=True, - foreground_only=True, - types: List[InteractionType] = [ - InteractionType.CLICK, - InteractionType.TYPE, - InteractionType.HOVER, - ], - ) -> PossibleInteractionsByXpath: - exe: Dict[str, List[str]] = self.driver.execute_script( - JS_GET_INTERACTIVES, - in_viewport, - foreground_only, - False, - [t.name for t in types], - ) - res = dict() - for k, v in exe.items(): - res[k] = set(InteractionType[i] for i in v) - return res - - def get_in_viewport(self): - res: Dict[str, List[str]] = self.driver.execute_script( - JS_GET_INTERACTIVES, - True, - True, - True, - ) - return list(res.keys()) - - def get_shadow_roots(self) -> Dict[str, str]: - return self.driver.execute_script(JS_GET_SHADOW_ROOTS) - - -class SeleniumNode(DOMNode): - def __init__( - self, - xpath: Optional[str], - driver: SeleniumDriver, - element: Optional[WebElement] = None, - ) -> None: - if not xpath: - raise NoSuchElementException("xpath is missing") - self.xpath = xpath - self._driver = driver - if element: - self._element = element - super().__init__() - - @property - def element(self) -> Optional[WebElement]: - if not hasattr(self, "_element"): - print("WARN: DOMNode context manager missing") - self.__enter__() - return self._element - - @property - def value(self) -> Any: - elem = self.element - return elem.get_attribute("value") if elem else None - - def highlight(self, color: str = "red", bounding_box=True): - self._driver.highlight_nodes([self.xpath], color, bounding_box) - return self - - def clear(self): - self._driver.remove_nodes_highlight([self.xpath]) - return self - - def take_screenshot(self): - with self: - if self.element: - try: - return Image.open(BytesIO(self.element.screenshot_as_png)) - except WebDriverException: - pass - return Image.new("RGB", (0, 0)) - - def get_html(self): - with self: - return self._driver.driver.execute_script( - "return arguments[0].outerHTML", self.element - ) - - def __enter__(self): - if hasattr(self, "_element"): - return self - - self._element = None - if not self.xpath: - return self - - root = self._driver.driver - local_xpath = self.xpath - - def find_element(xpath): - try: - if isinstance(root, ShadowRoot): - # Shadow root does not support find_element with xpath - css_selector = re.sub( - r"\[([0-9]+)\]", - r":nth-of-type(\1)", - xpath[1:].replace("/", " > "), - ) - return root.find_element(By.CSS_SELECTOR, css_selector) - return root.find_element(By.XPATH, xpath) - except Exception: - return None - - while local_xpath: - match = re.search(r"/iframe|//", local_xpath) - - if match: - before, sep, local_xpath = local_xpath.partition(match.group()) - if sep == "/iframe": - self._driver.switch_frame(before + sep) - elif sep == "//": - custom_element = find_element(before) - if not custom_element: - break - root = custom_element.shadow_root - local_xpath = "/" + local_xpath - else: - break - - self._element = find_element(local_xpath) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if hasattr(self, "_element"): - self._driver.switch_default_frame() - del self._element - - -class BrowserbaseRemoteConnection(RemoteConnection): - _session_id = None + def switch_frame(self, xpath: str) -> None: + iframe = self.driver.find_element(By.XPATH, xpath) + self.driver.switch_to.frame(iframe) - def __init__( - self, - remote_server_addr: str, - api_key: Optional[str] = None, - project_id: Optional[str] = None, - ): - super().__init__(remote_server_addr) - self.api_key = api_key or os.environ["BROWSERBASE_API_KEY"] - self.project_id = project_id or os.environ["BROWSERBASE_PROJECT_ID"] - - def get_remote_connection_headers(self, parsed_url, keep_alive=False): - if self._session_id is None: - self._session_id = self._create_session() - headers = super().get_remote_connection_headers(parsed_url, keep_alive) - headers.update({"x-bb-api-key": self.api_key}) - headers.update({"session-id": self._session_id}) - return headers - - def _create_session(self): - url = "https://www.browserbase.com/v1/sessions" - headers = {"Content-Type": "application/json", "x-bb-api-key": self.api_key} - response = requests.post( - url, json={"projectId": self.project_id}, headers=headers - ) - return response.json()["id"] - - -SELENIUM_PROMPT_TEMPLATE = """ -You are a chrome extension and your goal is to interact with web pages. You have been given a series of HTML snippets and queries. -Your goal is to return a list of actions that should be done in order to execute the actions. -Always target elements by using the full XPATH. You can only use one of the Xpaths included in the HTML. Do not derive new Xpaths. - -Your response must always be in the YAML format with the yaml markdown indicator and must include the main item "actions" , which will contains the objects "action", which contains the string "name" of tool of choice, and necessary arguments ("args") if required by the tool. -There must be only ONE args sub-object, such as args (if the tool has multiple arguments). -You must always include the comments as well, describing your actions step by step, following strictly the format in the examples provided. - -Provide high level explanations about why you think this element is the right one. -Your answer must be short and concise. Always includes comments in the YAML before listing the actions. - -The actions available are: - -Name: click -Description: Click on an element with a specific xpath -Arguments: - - xpath (string) - -Name: setValue -Description: Focus on and set the value of an input element with a specific xpath -Arguments: - - xpath (string) - - value (string) - -Name: dropdownSelect -Description: Select an option from a dropdown menu by its value -Arguments: - - xpath (string) - - value (string) - -Name: setValueAndEnter -Description: Like "setValue", except then it presses ENTER. Use this tool can submit the form when there's no "submit" button. -Arguments: - - xpath (string) - - value (string) - -Name: hover -Description: Move the mouse cursor over an element identified by the given xpath. It can be used to reveal tooltips or dropdown that appear on hover. It can also be used before scrolling to ensure the focus is in the correct container before performing the scroll action. -Arguments: - - xpath (string) - -Name: scroll -Description: Scroll the container that holds the element identified by the given xpath -Arguments: - - xpath (string) - - value (string): UP or DOWN - -Here are examples of previous answers: -HTML: -
Check in / Check out
-
-Query: Click on 'Home in Ploubazlanec' -Authorized Xpaths: "{'/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div', '/html/body/div[5]/d iv/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]/button/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div/div/div/div/div[2]/div/div/div/div/a/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]/button[2]', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div/div/div/div/div[2]/div/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]/button'}" -Completion: -```yaml -# Let's think through this step-by-step: -# 1. The query asks us to click on 'Home in Ploubazlanec' -# 2. In the HTML, we need to find an element that represents this listing -# 3. We can see a div with the text "Home in Ploubazlanec" in the title -# 4. The parent element of this div is an anchor tag, which is likely the clickable link for the listing -# 5. We should use the XPath of this anchor tag to perform the click action - -- actions: - - action: - # Click on the anchor tag that contains the listing title - args: - xpath: "/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div/div/div/div/div[2]/div/div/div/div/a" - name: "click" -``` ------ -HTML: -
-
- -More -
- - -Authorized Xpaths: "{'/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/a', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]/a', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]'}" -Query: Click on "Gemma" under the "More" dropdown menu. -Completion: -```yaml -# Let's think step by step -# First, we notice that the query asks us to click on the "Gemma" option under the "More" dropdown menu. -# In the provided HTML, we see that the "More" dropdown menu is within a tab element with a specific class and role attribute. -# The "More" dropdown menu can be identified by its class 'devsite-overflow-tab' and contains a link element with the text 'More'. -# We need to interact with this dropdown menu to reveal the hidden options. -# Specifically, for the "More" dropdown menu, there is an anchor element within a tab element: -# /html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/a - -- actions: - - action: - # We can use this XPATH to identify and click on the "More" dropdown menu: - args: - xpath: "/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/a" - value: "" - name: "click" - - action: - # After clicking the "More" dropdown, we need to select the "Gemma" option from the revealed menu. - # The "Gemma" option is located within the dropdown menu and can be identified by its anchor element with the corresponding text: - # /html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]/a - # Thus, we use this XPATH to identify and click on the "Gemma" option: - args: - xpath: "/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]/a" - value: "" - name: "click" -``` ------ -HTML: - -Authorized Xpaths: "{'/html/body/div/main/form/section/div/select'}" -Query: Select the 2:00 AM - 3:00 AM option from the dropdown menu -Completion: -```yaml -# Let's think step by step -# The query asks us to select the "2:00 AM - 3:00 AM" option from a dropdown menu. -# We need to identify the correct option within the dropdown menu based on its value attribute. -# The dropdown menu is specified by its XPATH, and the value of the option we need to select is "2". -# We can use the following "select" XPATH to locate the dropdown menu and the value "2" to select the appropriate option: -# /html/body/div/main/form/section/div/select - -- actions: - - action: - # Select the "3:00 AM - 4:00 AM" option by targeting the dropdown menu with the specified XPATH. - args: - xpath: "/html/body/div/main/form/section/div/select" - value: "2" - name: "dropdownSelect" -``` -""" + def switch_parent_frame(self) -> None: + self.driver.switch_to.parent_frame() diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py new file mode 100644 index 00000000..fae9f701 --- /dev/null +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/node.py @@ -0,0 +1,159 @@ +import re +from io import BytesIO +from typing import Optional + +from lavague.sdk.base_driver import DOMNode +from lavague.sdk.exceptions import ElementNotFoundException +from PIL import Image + +from selenium.common.exceptions import ( + ElementClickInterceptedException, + NoSuchElementException, + WebDriverException, +) +from selenium.webdriver.common.action_chains import ActionChains +from selenium.webdriver.common.by import By +from selenium.webdriver.common.keys import Keys +from selenium.webdriver.remote.shadowroot import ShadowRoot +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.remote.webelement import WebElement +from selenium.webdriver.support.ui import Select + + +class SeleniumNode(DOMNode[WebElement]): + def __init__( + self, + driver: WebDriver, + xpath: str, + element: Optional[WebElement] = None, + ) -> None: + self.driver = driver + self.xpath = xpath + if element: + self._element = element + super().__init__() + + @property + def element(self) -> WebElement: + if not hasattr(self, "_element"): + print("WARN: DOMNode context manager missing") + self.__enter__() + if self._element is None: + raise ElementNotFoundException(self.xpath) + return self._element + + @property + def value(self) -> Optional[str]: + return self.element.get_attribute("value") + + @property + def text(self) -> str: + return self.element.text + + @property + def outer_html(self) -> str: + return self.driver.execute_script("return arguments[0].outerHTML", self.element) + + @property + def inner_html(self) -> str: + return self.driver.execute_script("return arguments[0].innerHTML", self.element) + + def take_screenshot(self): + with self: + if self.element: + try: + return Image.open(BytesIO(self.element.screenshot_as_png)) + except WebDriverException: + pass + return Image.new("RGB", (0, 0)) + + def click(self): + with self: + try: + self.element.click() + except ElementClickInterceptedException: + try: + # Move to the element and click at its position + ActionChains(self.driver).move_to_element( + self.element + ).click().perform() + except Exception as click_error: + raise Exception( + f"Failed to click at element coordinates of {self.xpath} : {str(click_error)}" + ) + + def set_value(self, value: str): + with self: + if self.element.tag_name == "input": + try: + self.element.clear() + except WebDriverException: + pass + if self.element.tag_name == "select": + select = Select(self.element) + try: + select.select_by_value(value) + except NoSuchElementException: + select.select_by_visible_text(value) + else: + ( + ActionChains(self.driver) + .key_down(Keys.CONTROL) + .send_keys("a") + .key_up(Keys.CONTROL) + .send_keys(Keys.DELETE) # clear the input field + .send_keys(value) + .perform() + ) + + def hover(self): + with self: + ActionChains(self.driver).move_to_element(self.element).perform() + + def enter_context(self): + if hasattr(self, "_element"): + return + + root = self.driver + local_xpath = self.xpath + + def find_element(xpath): + try: + if isinstance(root, ShadowRoot): + # Shadow root does not support find_element with xpath + css_selector = re.sub( + r"\[([0-9]+)\]", + r":nth-of-type(\1)", + xpath[1:].replace("/", " > "), + ) + return root.find_element(By.CSS_SELECTOR, css_selector) + return root.find_element(By.XPATH, xpath) + except Exception: + return None + + while local_xpath: + match = re.search(r"/iframe|//", local_xpath) + + if match: + before, sep, local_xpath = local_xpath.partition(match.group()) + if sep == "/iframe": + iframe = self.driver.find_element(By.XPATH, before + sep) + self.driver.switch_to.frame(iframe) + elif sep == "//": + custom_element = find_element(before) + if not custom_element: + break + root = custom_element.shadow_root + local_xpath = "/" + local_xpath + else: + break + + self._element = find_element(local_xpath) + + if not self._element: + raise ElementNotFoundException(self.xpath) + + def exit_context(self): + if hasattr(self, "_element"): + self.driver.switch_to.default_content() + del self._element diff --git a/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/prompt.py b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/prompt.py new file mode 100644 index 00000000..4aea401b --- /dev/null +++ b/lavague-integrations/drivers/lavague-drivers-selenium/lavague/drivers/selenium/prompt.py @@ -0,0 +1,139 @@ +SELENIUM_PROMPT_TEMPLATE = """ +You are a chrome extension and your goal is to interact with web pages. You have been given a series of HTML snippets and queries. +Your goal is to return a list of actions that should be done in order to execute the actions. +Always target elements by using the full XPATH. You can only use one of the Xpaths included in the HTML. Do not derive new Xpaths. + +Your response must always be in the YAML format with the yaml markdown indicator and must include the main item "actions" , which will contains the objects "action", which contains the string "name" of tool of choice, and necessary arguments ("args") if required by the tool. +There must be only ONE args sub-object, such as args (if the tool has multiple arguments). +You must always include the comments as well, describing your actions step by step, following strictly the format in the examples provided. + +Provide high level explanations about why you think this element is the right one. +Your answer must be short and concise. Always includes comments in the YAML before listing the actions. + +The actions available are: + +Name: click +Description: Click on an element with a specific xpath +Arguments: + - xpath (string) + +Name: setValue +Description: Focus on and set the value of an input element with a specific xpath +Arguments: + - xpath (string) + - value (string) + +Name: dropdownSelect +Description: Select an option from a dropdown menu by its value +Arguments: + - xpath (string) + - value (string) + +Name: setValueAndEnter +Description: Like "setValue", except then it presses ENTER. Use this tool can submit the form when there's no "submit" button. +Arguments: + - xpath (string) + - value (string) + +Name: hover +Description: Move the mouse cursor over an element identified by the given xpath. It can be used to reveal tooltips or dropdown that appear on hover. It can also be used before scrolling to ensure the focus is in the correct container before performing the scroll action. +Arguments: + - xpath (string) + +Name: scroll +Description: Scroll the container that holds the element identified by the given xpath +Arguments: + - xpath (string) + - value (string): UP or DOWN + +Here are examples of previous answers: +HTML: +
Check in / Check out
+
+Query: Click on 'Home in Ploubazlanec' +Authorized Xpaths: "{'/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div', '/html/body/div[5]/d iv/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]/button/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div/div/div/div/div[2]/div/div/div/div/a/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]/button[2]', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div/div/div/div/div[2]/div/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div', '/html/body/div[5]/div/div/div/div/div[3]/header/div/div/div/div/div/div[2]/div/div/span[2]/button'}" +Completion: +```yaml +# Let's think through this step-by-step: +# 1. The query asks us to click on 'Home in Ploubazlanec' +# 2. In the HTML, we need to find an element that represents this listing +# 3. We can see a div with the text "Home in Ploubazlanec" in the title +# 4. The parent element of this div is an anchor tag, which is likely the clickable link for the listing +# 5. We should use the XPath of this anchor tag to perform the click action + +- actions: + - action: + # Click on the anchor tag that contains the listing title + args: + xpath: "/html/body/div[5]/div/div/div/div/div[3]/div/main/div[2]/div/div[2]/div/div/div/div/div/div/div/div[2]/div/div/div/div/div/div/div/div/div[2]/div/div/div/div/a" + name: "click" +``` +----- +HTML: +
+
+ +More +
+ + +Authorized Xpaths: "{'/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/a', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]/a', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]', '/html/body/section/devsite-header/div/div[1]/div/div/div[2]'}" +Query: Click on "Gemma" under the "More" dropdown menu. +Completion: +```yaml +# Let's think step by step +# First, we notice that the query asks us to click on the "Gemma" option under the "More" dropdown menu. +# In the provided HTML, we see that the "More" dropdown menu is within a tab element with a specific class and role attribute. +# The "More" dropdown menu can be identified by its class 'devsite-overflow-tab' and contains a link element with the text 'More'. +# We need to interact with this dropdown menu to reveal the hidden options. +# Specifically, for the "More" dropdown menu, there is an anchor element within a tab element: +# /html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/a + +- actions: + - action: + # We can use this XPATH to identify and click on the "More" dropdown menu: + args: + xpath: "/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/a" + value: "" + name: "click" + - action: + # After clicking the "More" dropdown, we need to select the "Gemma" option from the revealed menu. + # The "Gemma" option is located within the dropdown menu and can be identified by its anchor element with the corresponding text: + # /html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]/a + # Thus, we use this XPATH to identify and click on the "Gemma" option: + args: + xpath: "/html/body/section/devsite-header/div/div[1]/div/div/div[2]/div[1]/devsite-tabs/nav/tab[2]/div/tab[1]/a" + value: "" + name: "click" +``` +----- +HTML: + +Authorized Xpaths: "{'/html/body/div/main/form/section/div/select'}" +Query: Select the 2:00 AM - 3:00 AM option from the dropdown menu +Completion: +```yaml +# Let's think step by step +# The query asks us to select the "2:00 AM - 3:00 AM" option from a dropdown menu. +# We need to identify the correct option within the dropdown menu based on its value attribute. +# The dropdown menu is specified by its XPATH, and the value of the option we need to select is "2". +# We can use the following "select" XPATH to locate the dropdown menu and the value "2" to select the appropriate option: +# /html/body/div/main/form/section/div/select + +- actions: + - action: + # Select the "3:00 AM - 4:00 AM" option by targeting the dropdown menu with the specified XPATH. + args: + xpath: "/html/body/div/main/form/section/div/select" + value: "2" + name: "dropdownSelect" +``` +""" diff --git a/lavague-sdk/lavague/sdk/agent.py b/lavague-sdk/lavague/sdk/agent.py index 5201d146..c472e41e 100644 --- a/lavague-sdk/lavague/sdk/agent.py +++ b/lavague-sdk/lavague/sdk/agent.py @@ -1,6 +1,6 @@ from typing import Optional from lavague.sdk.trajectory import Trajectory -from lavague.sdk.client import LaVague +from lavague.sdk.client import LaVague, RunRequest from lavague.sdk.utilities.config import get_config @@ -15,15 +15,26 @@ def __init__( self, api_key: Optional[str] = None, client: Optional[LaVague] = None, + create_public_runs: bool = False, ): if client is None: if api_key is None: api_key = get_config("LAVAGUE_API_KEY") client = LaVague(api_key=api_key) self.client = client + self.create_public_runs = create_public_runs - def run(self, url: str, objective: str, async_run=False) -> Trajectory: - trajectory = self.client.run(url, objective, step_by_step=True) + def run( + self, url: str, objective: str, async_run=False, viewport_size=(1096, 1096) + ) -> Trajectory: + request = RunRequest( + url=url, + objective=objective, + step_by_step=True, + is_public=self.create_public_runs, + viewport_size=viewport_size, + ) + trajectory = self.client.run(request) if not async_run: trajectory.run_to_completion() return trajectory diff --git a/lavague-sdk/lavague/sdk/base_driver.py b/lavague-sdk/lavague/sdk/base_driver.py deleted file mode 100644 index 2c9dfefe..00000000 --- a/lavague-sdk/lavague/sdk/base_driver.py +++ /dev/null @@ -1,747 +0,0 @@ -from PIL import Image -import os -from pathlib import Path -import re -from typing import Any, Callable, Optional, Mapping, Dict, Set, List, Tuple, Union -from abc import ABC, abstractmethod -from enum import Enum -from datetime import datetime -import hashlib - - -class InteractionType(Enum): - CLICK = "click" - HOVER = "hover" - SCROLL = "scroll" - TYPE = "type" - - -PossibleInteractionsByXpath = Dict[str, Set[InteractionType]] - -r_get_xpaths_from_html = r'xpath=["\'](.*?)["\']' - - -class ScrollDirection(Enum): - """Enum for the different scroll directions. Value is (x, y, dimension_index)""" - - LEFT = (-1, 0, 0) - RIGHT = (1, 0, 0) - UP = (0, -1, 1) - DOWN = (0, 1, 1) - - def get_scroll_xy( - self, dimension: List[float], scroll_factor: float = 0.75 - ) -> Tuple[int, int]: - size = dimension[self.value[2]] - return ( - round(self.value[0] * size * scroll_factor), - round(self.value[1] * size * scroll_factor), - ) - - def get_page_script(self, scroll_factor: float = 0.75) -> str: - return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);" - - def get_script_element_is_scrollable(self) -> str: - match self: - case ScrollDirection.UP: - return "return arguments[0].scrollTop > 0" - case ScrollDirection.DOWN: - return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight" - case ScrollDirection.LEFT: - return "return arguments[0].scrollLeft > 0" - case ScrollDirection.RIGHT: - return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth" - - def get_script_page_is_scrollable(self) -> str: - match self: - case ScrollDirection.UP: - return "return window.scrollY > 0" - case ScrollDirection.DOWN: - return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight" - case ScrollDirection.LEFT: - return "return window.scrollX > 0" - case ScrollDirection.RIGHT: - return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth" - - @classmethod - def from_string(cls, name: str) -> "ScrollDirection": - return cls[name.upper().strip()] - - -class BaseDriver(ABC): - def __init__(self, url: Optional[str], init_function: Optional[Callable[[], Any]]): - """Init the driver with the init funtion, and then go to the desired url""" - self.init_function = ( - init_function if init_function is not None else self.default_init_code - ) - self.driver = self.init_function() - - # Flag to check if the page has been previously scanned to avoid erasing screenshots from previous scan - self.previously_scanned = False - - if url is not None: - self.get(url) - - async def connect(self) -> None: - """Connect to the driver""" - pass - - @abstractmethod - def default_init_code(self) -> Any: - """Init the driver, with the imports, since it will be pasted to the beginning of the output""" - pass - - @abstractmethod - def destroy(self) -> None: - """Cleanly destroy the underlying driver""" - pass - - @abstractmethod - def get_driver(self) -> Any: - """Return the expected variable name and the driver object""" - pass - - @abstractmethod - def resize_driver(driver, width, height): - """ - Resize the driver to a targeted height and width. - """ - - @abstractmethod - def get_url(self) -> Optional[str]: - """Get the url of the current page""" - pass - - @abstractmethod - def get(self, url: str) -> None: - """Navigate to the url""" - pass - - @abstractmethod - def code_for_get(self, url: str) -> str: - """Return the code to navigate to the url""" - pass - - @abstractmethod - def back(self) -> None: - """Navigate back""" - pass - - @abstractmethod - def maximize_window(self) -> None: - pass - - @abstractmethod - def code_for_back(self) -> None: - """Return driver specific code for going back""" - pass - - @abstractmethod - def get_html(self, clean: bool = True) -> str: - """ - Returns the HTML of the current page. - If clean is True, We remove unnecessary tags and attributes from the HTML. - Clean HTMLs are easier to process for the LLM. - """ - pass - - def get_tabs(self) -> str: - """Return description of the tabs opened with the current tab being focused. - - Example of output: - Tabs opened: - 0 - Overview - OpenAI API - 1 - [CURRENT] Nos destinations Train - SNCF Connect - """ - return "Tabs opened:\n 0 - [CURRENT] tab" - - def switch_tab(self, tab_id: int) -> None: - """Switch to the tab with the given id""" - pass - - def switch_frame(self, xpath) -> None: - """ - switch to the frame pointed at by the xpath - """ - raise NotImplementedError() - - def switch_default_frame(self) -> None: - """ - Switch back to the default frame - """ - raise NotImplementedError() - - def switch_parent_frame(self) -> None: - """ - Switch back to the parent frame - """ - raise NotImplementedError() - - @abstractmethod - def resolve_xpath(self, xpath) -> "DOMNode": - """ - Return the element for the corresponding xpath, the underlying driver may switch iframe if necessary - """ - pass - - def save_screenshot(self, current_screenshot_folder: Path) -> str: - """Save the screenshot data to a file and return the path. If the screenshot already exists, return the path. If not save it to the folder.""" - - new_screenshot = self.get_screenshot_as_png() - hasher = hashlib.md5() - hasher.update(new_screenshot) - new_hash = hasher.hexdigest() - new_screenshot_name = f"{new_hash}.png" - new_screenshot_full_path = current_screenshot_folder / new_screenshot_name - - # If the screenshot does not exist, save it - if not new_screenshot_full_path.exists(): - with open(new_screenshot_full_path, "wb") as f: - f.write(new_screenshot) - return str(new_screenshot_full_path) - - def is_bottom_of_page(self) -> bool: - return self.execute_script( - "return (window.innerHeight + window.scrollY + 1) >= document.body.scrollHeight;" - ) - - def get_screenshots_whole_page(self, max_screenshots=30) -> list[str]: - """Take screenshots of the whole page""" - screenshot_paths = [] - - current_screenshot_folder = self.get_current_screenshot_folder() - - for i in range(max_screenshots): - # Saves a screenshot - screenshot_path = self.save_screenshot(current_screenshot_folder) - screenshot_paths.append(screenshot_path) - self.scroll_down() - self.wait_for_idle() - - if self.is_bottom_of_page(): - break - - self.previously_scanned = True - return screenshot_paths - - @abstractmethod - def get_possible_interactions( - self, - in_viewport=True, - foreground_only=True, - types: List[InteractionType] = [ - InteractionType.CLICK, - InteractionType.TYPE, - InteractionType.HOVER, - ], - ) -> PossibleInteractionsByXpath: - """Get elements that can be interacted with as a dictionary mapped by xpath""" - pass - - def get_in_viewport(self) -> List[str]: - """Get xpath of elements in the viewport""" - return [] - - def check_visibility(self, xpath: str) -> bool: - return True - - @abstractmethod - def get_viewport_size(self) -> dict: - """Return viewport size as {"width": int, "height": int}""" - pass - - @abstractmethod - def get_highlighted_element(self, generated_code: str): - """Return the page elements that generated code interact with""" - pass - - @abstractmethod - def exec_code( - self, - code: str, - globals: dict[str, Any] = None, - locals: Mapping[str, object] = None, - ): - """Exec generated code""" - pass - - @abstractmethod - def execute_script(self, js_code: str, *args) -> Any: - """Exec js script in DOM""" - pass - - @abstractmethod - def scroll( - self, - xpath_anchor: Optional[str], - direction: ScrollDirection, - scroll_factor=0.75, - ): - pass - - # TODO: Remove these methods as they are not used - @abstractmethod - def scroll_up(self): - pass - - @abstractmethod - def scroll_down(self): - pass - - @abstractmethod - def code_for_execute_script(self, js_code: str): - """return driver specific code to execute js script in DOM""" - pass - - @abstractmethod - def get_capability(self) -> str: - """Prompt to explain the llm which style of code he should output and which variables and imports he should expect""" - pass - - def get_obs(self) -> dict: - """Get the current observation of the driver""" - current_screenshot_folder = self.get_current_screenshot_folder() - - if not self.previously_scanned: - # If the last operation was not to scan the whole page, we clear the screenshot folder - try: - if os.path.isdir(current_screenshot_folder): - for filename in os.listdir(current_screenshot_folder): - file_path = os.path.join(current_screenshot_folder, filename) - try: - # Check if it's a file and then delete it - if os.path.isfile(file_path) or os.path.islink(file_path): - os.remove(file_path) - except Exception as e: - print(f"Failed to delete {file_path}. Reason: {e}") - - except Exception as e: - raise Exception(f"Error while clearing screenshot folder: {e}") - else: - # If the last operation was to scan the whole page, we reset the flag - self.previously_scanned = False - - # We add labels to the scrollable elements - i_scroll = self.get_possible_interactions(types=[InteractionType.SCROLL]) - scrollables_xpaths = list(i_scroll.keys()) - - self.remove_highlight() - self.highlight_nodes(scrollables_xpaths, label=True) - - # We take a screenshot and computes its hash to see if it already exists - self.save_screenshot(current_screenshot_folder) - self.remove_highlight() - - url = self.get_url() - html = self.get_html() - obs = { - "html": html, - "screenshots_path": str(current_screenshot_folder), - "url": url, - "date": datetime.now().isoformat(), - "tab_info": self.get_tabs(), - } - - return obs - - def wait(self, duration): - import time - - time.sleep(duration) - - def wait_for_idle(self): - pass - - def get_current_screenshot_folder(self) -> Path: - url = self.get_url() - - if url is None: - url = "blank" - - screenshots_path = Path("./screenshots") - screenshots_path.mkdir(exist_ok=True) - - current_url = url.replace("://", "_").replace("/", "_") - hasher = hashlib.md5() - hasher.update(current_url.encode("utf-8")) - - current_screenshot_folder = screenshots_path / hasher.hexdigest() - current_screenshot_folder.mkdir(exist_ok=True) - return current_screenshot_folder - - @abstractmethod - def get_screenshot_as_png(self) -> bytes: - pass - - @abstractmethod - def get_shadow_roots(self) -> Dict[str, str]: - pass - - def get_nodes(self, xpaths: List[str]) -> List["DOMNode"]: - raise NotImplementedError("get_nodes not implemented") - - def get_nodes_from_html(self, html: str) -> List["DOMNode"]: - return self.get_nodes(re.findall(r_get_xpaths_from_html, html)) - - def highlight_node_from_xpath( - self, xpath: str, color: str = "red", label=False - ) -> Callable: - return self.highlight_nodes([xpath], color, label) - - def highlight_nodes( - self, xpaths: List[str], color: str = "red", label=False - ) -> Callable: - nodes = self.get_nodes(xpaths) - for n in nodes: - n.highlight(color) - return self._add_highlighted_destructors(lambda: [n.clear() for n in nodes]) - - def highlight_nodes_from_html( - self, html: str, color: str = "blue", label=False - ) -> Callable: - return self.highlight_nodes( - re.findall(r_get_xpaths_from_html, html), color, label - ) - - def remove_highlight(self): - if hasattr(self, "_highlight_destructors"): - for destructor in self._highlight_destructors: - destructor() - delattr(self, "_highlight_destructors") - - def _add_highlighted_destructors( - self, destructors: Union[List[Callable], Callable] - ) -> Callable: - if not hasattr(self, "_highlight_destructors"): - self._highlight_destructors = [] - if isinstance(destructors, Callable): - self._highlight_destructors.append(destructors) - return destructors - - self._highlight_destructors.extend(destructors) - return lambda: [d() for d in destructors] - - def highlight_interactive_nodes( - self, - *with_interactions: tuple[InteractionType], - color: str = "red", - in_viewport=True, - foreground_only=True, - label=False, - ): - if with_interactions is None or len(with_interactions) == 0: - return self.highlight_nodes( - list( - self.get_possible_interactions( - in_viewport=in_viewport, foreground_only=foreground_only - ).keys() - ), - color, - label, - ) - - return self.highlight_nodes( - [ - xpath - for xpath, interactions in self.get_possible_interactions( - in_viewport=in_viewport, foreground_only=foreground_only - ).items() - if set(interactions) & set(with_interactions) - ], - color, - label, - ) - - -class DOMNode(ABC): - @property - @abstractmethod - def element(self) -> Any: - pass - - @property - @abstractmethod - def value(self) -> Any: - pass - - @abstractmethod - def highlight(self, color: str = "red", bounding_box=True): - pass - - @abstractmethod - def clear(self): - return self - - @abstractmethod - def take_screenshot(self) -> Image.Image: - pass - - @abstractmethod - def get_html(self) -> str: - pass - - def __str__(self) -> str: - return self.get_html() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - -def js_wrap_function_call(fn: str): - return "(function(){" + fn + "})()" - - -JS_SETUP_GET_EVENTS = """ -(function() { - if (window && !window.getEventListeners) { - const targetProto = EventTarget.prototype; - targetProto._addEventListener = Element.prototype.addEventListener; - targetProto.addEventListener = function(a,b,c) { - this._addEventListener(a,b,c); - if(!this.eventListenerList) this.eventListenerList = {}; - if(!this.eventListenerList[a]) this.eventListenerList[a] = []; - this.eventListenerList[a].push(b); - }; - targetProto._removeEventListener = Element.prototype.removeEventListener; - targetProto.removeEventListener = function(a, b, c) { - this._removeEventListener(a, b, c); - if(this.eventListenerList && this.eventListenerList[a]) { - const index = this.eventListenerList[a].indexOf(b); - if (index > -1) { - this.eventListenerList[a].splice(index, 1); - if (!this.eventListenerList[a].length) { - delete this.eventListenerList[a]; - } - } - } - }; - window.getEventListeners = function(e) { - return (e && e.eventListenerList) || []; - } - } -})();""" - -JS_GET_INTERACTIVES = """ -const windowHeight = (window.innerHeight || document.documentElement.clientHeight); -const windowWidth = (window.innerWidth || document.documentElement.clientWidth); - -return (function(inViewport, foregroundOnly, nonInteractives, filterTypes) { - function getInteractions(e) { - const tag = e.tagName.toLowerCase(); - if (!e.checkVisibility() || (e.hasAttribute('disabled') && !nonInteractives) || e.hasAttribute('readonly') - || (tag === 'input' && e.getAttribute('type') === 'hidden') || tag === 'body') { - return []; - } - const rect = e.getBoundingClientRect(); - if (rect.width + rect.height < 5) { - return []; - } - const style = getComputedStyle(e) || {}; - if (style.display === 'none' || style.visibility === 'hidden') { - return []; - } - const events = window && typeof window.getEventListeners === 'function' ? window.getEventListeners(e) : []; - const role = e.getAttribute('role'); - const clickableInputs = ['submit', 'checkbox', 'radio', 'color', 'file', 'image', 'reset']; - function hasEvent(n) { - return events[n]?.length || e.hasAttribute('on' + n); - } - let evts = []; - if (hasEvent('keydown') || hasEvent('keyup') || hasEvent('keypress') || hasEvent('keydown') || hasEvent('input') || e.isContentEditable - || ( - (tag === 'input' || tag === 'textarea' || role === 'searchbox' || role === 'input') - ) && !clickableInputs.includes(e.getAttribute('type')) - ) { - evts.push('TYPE'); - } - if (['a', 'button', 'select'].includes(tag) || ['button', 'checkbox', 'select'].includes(role) - || hasEvent('click') || hasEvent('mousedown') || hasEvent('mouseup') || hasEvent('dblclick') - || style.cursor === 'pointer' - || e.hasAttribute('aria-haspopup') - || (tag === 'input' && clickableInputs.includes(e.getAttribute('type'))) - || (tag === 'label' && document.getElementById(e.getAttribute('for'))) - ) { - evts.push('CLICK'); - } - if ( - (hasEvent('scroll') || hasEvent('wheel') || style.overflow === 'auto' || style.overflow === 'scroll' || style.overflowY === 'auto' || style.overflowY === 'scroll') - && (e.scrollHeight > e.clientHeight || e.scrollWidth > e.clientWidth)) { - evts.push('SCROLL'); - } - if (filterTypes && evts.length) { - evts = evts.filter(t => filterTypes.includes(t)); - } - if (nonInteractives && evts.length === 0) { - evts.push('NONE'); - } - - if (inViewport) { - let rect = e.getBoundingClientRect(); - let iframe = e.ownerDocument.defaultView.frameElement; - while (iframe) { - const iframeRect = iframe.getBoundingClientRect(); - rect = { - top: rect.top + iframeRect.top, - left: rect.left + iframeRect.left, - bottom: rect.bottom + iframeRect.bottom, - right: rect.right + iframeRect.right, - width: rect.width, - height: rect.height - } - iframe = iframe.ownerDocument.defaultView.frameElement; - } - const elemCenter = { - x: Math.round(rect.left + rect.width / 2), - y: Math.round(rect.top + rect.height / 2) - }; - if (elemCenter.x < 0) return []; - if (elemCenter.x > windowWidth) return []; - if (elemCenter.y < 0) return []; - if (elemCenter.y > windowHeight) return []; - if (!foregroundOnly) return evts; // whenever to check for elements above - let pointContainer = document.elementFromPoint(elemCenter.x, elemCenter.y); - iframe = e.ownerDocument.defaultView.frameElement; - while (iframe) { - const iframeRect = iframe.getBoundingClientRect(); - pointContainer = iframe.contentDocument.elementFromPoint( - elemCenter.x - iframeRect.left, - elemCenter.y - iframeRect.top - ); - iframe = iframe.ownerDocument.defaultView.frameElement; - } - do { - if (pointContainer === e) return evts; - if (pointContainer == null) return evts; - } while (pointContainer = pointContainer.parentNode); - return []; - } - return evts; - } - - const results = {}; - function traverse(node, xpath) { - if (node.nodeType === Node.ELEMENT_NODE) { - const interactions = getInteractions(node); - if (interactions.length > 0) { - results[xpath] = interactions; - } - } - const countByTag = {}; - for (let child = node.firstChild; child; child = child.nextSibling) { - let tag = child.nodeName.toLowerCase(); - if (tag.includes(":")) continue; //namespace - let isLocal = ['svg'].includes(tag); - if (isLocal) { - tag = `*[local-name() = '${tag}']`; - } - countByTag[tag] = (countByTag[tag] || 0) + 1; - let childXpath = xpath + '/' + tag; - if (countByTag[tag] > 1) { - childXpath += '[' + countByTag[tag] + ']'; - } - if (tag === 'iframe') { - try { - traverse(child.contentWindow.document.body, childXpath + '/html/body'); - } catch (e) { - console.warn("iframe access blocked", child, e); - } - } else if (!isLocal) { - traverse(child, childXpath); - if (child.shadowRoot) { - traverse(child.shadowRoot, childXpath + '/'); - } - } - } - } - traverse(document.body, '/html/body'); - return results; -})(arguments?.[0], arguments?.[1], arguments?.[2], arguments?.[3]); -""" - -JS_WAIT_DOM_IDLE = """ -return new Promise(resolve => { - const timeout = arguments[0] || 10000; - const stabilityThreshold = arguments[1] || 100; - - let mutationObserver; - let timeoutId = null; - - const waitForIdle = () => { - if (timeoutId) clearTimeout(timeoutId); - timeoutId = setTimeout(() => resolve(true), stabilityThreshold); - }; - mutationObserver = new MutationObserver(waitForIdle); - mutationObserver.observe(document.body, { - childList: true, - attributes: true, - subtree: true, - }); - waitForIdle(); - - setTimeout(() => { - resolve(false); - mutationObserver.disconnect(); - mutationObserver = null; - if (timeoutId) { - clearTimeout(timeoutId); - timeoutId = null; - } - }, timeout); -}); -""" - -JS_GET_SCROLLABLE_PARENT = """ -let element = arguments[0]; -while (element) { - const style = window.getComputedStyle(element); - - // Check if the element is scrollable - if (style.overflow === 'auto' || style.overflow === 'scroll' || - style.overflowX === 'auto' || style.overflowX === 'scroll' || - style.overflowY === 'auto' || style.overflowY === 'scroll') { - - // Check if the element has a scrollable area - if (element.scrollHeight > element.clientHeight || - element.scrollWidth > element.clientWidth) { - return element; - } - } - element = element.parentElement; -} -return null; -""" - -JS_GET_SHADOW_ROOTS = """ -const results = {}; -function traverse(node, xpath) { - if (node.shadowRoot) { - results[xpath] = node.shadowRoot.getHTML(); - } - const countByTag = {}; - for (let child = node.firstChild; child; child = child.nextSibling) { - let tag = child.nodeName.toLowerCase(); - countByTag[tag] = (countByTag[tag] || 0) + 1; - let childXpath = xpath + '/' + tag; - if (countByTag[tag] > 1) { - childXpath += '[' + countByTag[tag] + ']'; - } - if (child.shadowRoot) { - traverse(child.shadowRoot, childXpath + '/'); - } - if (tag === 'iframe') { - try { - traverse(child.contentWindow.document.body, childXpath + '/html/body'); - } catch (e) { - console.warn("iframe access blocked", child, e); - } - } else { - traverse(child, childXpath); - } - } -} -traverse(document.body, '/html/body'); -return results; -""" diff --git a/lavague-sdk/lavague/sdk/base_driver/__init__.py b/lavague-sdk/lavague/sdk/base_driver/__init__.py new file mode 100644 index 00000000..b036ad9e --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/__init__.py @@ -0,0 +1 @@ +from lavague.sdk.base_driver.base import BaseDriver, DOMNode, DriverObservation diff --git a/lavague-sdk/lavague/sdk/base_driver/base.py b/lavague-sdk/lavague/sdk/base_driver/base.py new file mode 100644 index 00000000..6571edb8 --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/base.py @@ -0,0 +1,294 @@ +import re +from abc import ABC, abstractmethod +from contextlib import contextmanager +from datetime import datetime +from typing import Callable, Dict, List, Optional, Union, TypeVar, Generic +from pydantic import BaseModel +from lavague.sdk.action.navigation import NavigationOutput + +from lavague.sdk.base_driver.interaction import ( + InteractionType, + PossibleInteractionsByXpath, + ScrollDirection, +) +from lavague.sdk.base_driver.node import DOMNode + + +class DriverObservation(BaseModel): + html: str + screenshot: bytes + url: str + date: str + tab_info: str + + +T = TypeVar("T", bound=DOMNode, covariant=True) + + +class BaseDriver(ABC, Generic[T]): + @abstractmethod + def init(self) -> None: + """Init the underlying driver""" + pass + + def execute(self, action: NavigationOutput) -> None: + """Execute an action""" + with self.resolve_xpath(action.xpath) as node: + match action.navigation_command: + case InteractionType.CLICK: + node.click() + + case InteractionType.TYPE: + node.set_value(action.value or "") + + case InteractionType.HOVER: + node.hover() + + case InteractionType.SCROLL: + direction = ScrollDirection.from_string(action.value or "DOWN") + self.scroll(action.xpath, direction) + + case _: + raise NotImplementedError( + f"Action {action.navigation_command} not implemented" + ) + + @abstractmethod + def destroy(self) -> None: + """Cleanly destroy the underlying driver""" + pass + + @abstractmethod + def resize_driver(self, width: int, height: int): + """Resize the viewport to a targeted height and width""" + + @abstractmethod + def get_url(self) -> str: + """Get the url of the current page, raise NoPageException if no page is loaded""" + pass + + @abstractmethod + def get(self, url: str) -> None: + """Navigate to the url""" + pass + + @abstractmethod + def back(self) -> None: + """Navigate back, raise CannotBackException if history root is reached""" + pass + + @abstractmethod + def get_html(self) -> str: + """ + Returns the HTML of the current page. + If clean is True, We remove unnecessary tags and attributes from the HTML. + Clean HTMLs are easier to process for the LLM. + """ + pass + + @abstractmethod + def get_tabs(self) -> str: + """Return description of the tabs opened with the current tab being focused. + + Example of output: + Tabs opened: + 0 - Overview - OpenAI API + 1 - [CURRENT] Nos destinations Train - SNCF Connect + """ + pass + + @abstractmethod + def switch_tab(self, tab_id: int) -> None: + """Switch to the tab with the given id""" + pass + + @abstractmethod + def resolve_xpath(self, xpath: str) -> T: + """ + Return the element for the corresponding xpath, the underlying driver may switch iframe if necessary + """ + pass + + @abstractmethod + def get_viewport_size(self) -> dict: + """Return viewport size as {"width": int, "height": int}""" + pass + + @abstractmethod + def get_possible_interactions( + self, + in_viewport=True, + foreground_only=True, + types: List[InteractionType] = [ + InteractionType.CLICK, + InteractionType.TYPE, + InteractionType.HOVER, + ], + ) -> PossibleInteractionsByXpath: + """Get elements that can be interacted with as a dictionary mapped by xpath""" + pass + + @abstractmethod + def scroll( + self, + xpath_anchor: Optional[str] = "/html/body", + direction: ScrollDirection = ScrollDirection.DOWN, + scroll_factor=0.75, + ): + pass + + @abstractmethod + def scroll_into_view(self, xpath: str): + pass + + @abstractmethod + def wait_for_idle(self): + pass + + @abstractmethod + def get_capability(self) -> str: + """Prompt to explain the llm which style of code he should output and which variables and imports he should expect""" + pass + + @abstractmethod + def get_screenshot_as_png(self) -> bytes: + pass + + @abstractmethod + def get_shadow_roots(self) -> Dict[str, str]: + """Return a dictionary of shadow roots HTML by xpath""" + pass + + @abstractmethod + def get_nodes(self, xpaths: List[str]) -> List[T]: + pass + + @abstractmethod + def highlight_nodes( + self, xpaths: List[str], color: str = "red", label=False + ) -> Callable: + pass + + @abstractmethod + def switch_frame(self, xpath: str) -> None: + """Switch to the frame with the given xpath, use with care as it changes the state of the driver""" + pass + + @abstractmethod + def switch_parent_frame(self) -> None: + """Switch to the parent frame, use with care as it changes the state of the driver""" + pass + + @contextmanager + def nodes_highlighter(self, nodes: List[str], color: str = "red", label=False): + """Highlight nodes for a context manager""" + remove_highlight = self.highlight_nodes(nodes, color, label) + yield + remove_highlight() + + def get_obs(self) -> DriverObservation: + """Get the current observation of the driver""" + + # We add labels to the scrollable elements + scrollables = self.get_scroll_containers() + with self.nodes_highlighter(scrollables, label=True): + screenshot = self.get_screenshot_as_png() + + url = self.get_url() + html = self.get_html() + tab_info = self.get_tabs() + + return DriverObservation( + html=html, + screenshot=screenshot, + url=url, + date=datetime.now().isoformat(), + tab_info=tab_info, + ) + + def get_in_viewport(self) -> List[str]: + """Get xpath of elements in the viewport""" + interactions = self.get_possible_interactions(in_viewport=True, types=[]) + return list(interactions.keys()) + + def get_scroll_containers(self) -> List[str]: + """Get xpath of elements in the viewport""" + interactions = self.get_possible_interactions(types=[InteractionType.SCROLL]) + return list(interactions.keys()) + + def get_nodes_from_html(self, html: str) -> List[T]: + return self.get_nodes(re.findall(r"xpath=[\"'](.*?)[\"']", html)) + + def highlight_node_from_xpath( + self, xpath: str, color: str = "red", label=False + ) -> Callable: + return self.highlight_nodes([xpath], color, label) + + def highlight_nodes_from_html( + self, html: str, color: str = "blue", label=False + ) -> Callable: + return self.highlight_nodes( + re.findall(r"xpath=[\"'](.*?)[\"']", html), color, label + ) + + def remove_highlight(self): + if hasattr(self, "_highlight_destructors"): + for destructor in self._highlight_destructors: + destructor() + delattr(self, "_highlight_destructors") + + def _add_highlighted_destructors( + self, destructors: Union[List[Callable], Callable] + ) -> Callable: + if not hasattr(self, "_highlight_destructors"): + self._highlight_destructors = [] + if isinstance(destructors, Callable): + self._highlight_destructors.append(destructors) + return destructors + + self._highlight_destructors.extend(destructors) + return lambda: [d() for d in destructors] + + def highlight_interactive_nodes( + self, + *with_interactions: tuple[InteractionType], + color: str = "red", + in_viewport=True, + foreground_only=True, + label=False, + ): + if with_interactions is None or len(with_interactions) == 0: + return self.highlight_nodes( + list( + self.get_possible_interactions( + in_viewport=in_viewport, foreground_only=foreground_only + ).keys() + ), + color, + label, + ) + + return self.highlight_nodes( + [ + xpath + for xpath, interactions in self.get_possible_interactions( + in_viewport=in_viewport, foreground_only=foreground_only + ).items() + if set(interactions) & set(with_interactions) + ], + color, + label, + ) + + def __enter__(self): + self.init() + self.driver_ready = True + return self + + def __exit__(self, *_): + self.destroy() + self.driver_ready = False + + def __del__(self): + if self.driver_ready: + self.__exit__() diff --git a/lavague-sdk/lavague/sdk/base_driver/interaction.py b/lavague-sdk/lavague/sdk/base_driver/interaction.py new file mode 100644 index 00000000..4024c735 --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/interaction.py @@ -0,0 +1,59 @@ +from typing import Dict, Set, List, Tuple +from enum import Enum + + +class InteractionType(Enum): + CLICK = "click" + HOVER = "hover" + SCROLL = "scroll" + TYPE = "type" + + +PossibleInteractionsByXpath = Dict[str, Set[InteractionType]] + + +class ScrollDirection(Enum): + """Enum for the different scroll directions. Value is (x, y, dimension_index)""" + + LEFT = (-1, 0, 0) + RIGHT = (1, 0, 0) + UP = (0, -1, 1) + DOWN = (0, 1, 1) + + def get_scroll_xy( + self, dimension: List[float], scroll_factor: float = 0.75 + ) -> Tuple[int, int]: + size = dimension[self.value[2]] + return ( + round(self.value[0] * size * scroll_factor), + round(self.value[1] * size * scroll_factor), + ) + + def get_page_script(self, scroll_factor: float = 0.75) -> str: + return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);" + + def get_script_element_is_scrollable(self) -> str: + match self: + case ScrollDirection.UP: + return "return arguments[0].scrollTop > 0" + case ScrollDirection.DOWN: + return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight" + case ScrollDirection.LEFT: + return "return arguments[0].scrollLeft > 0" + case ScrollDirection.RIGHT: + return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth" + + def get_script_page_is_scrollable(self) -> str: + match self: + case ScrollDirection.UP: + return "return window.scrollY > 0" + case ScrollDirection.DOWN: + return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight" + case ScrollDirection.LEFT: + return "return window.scrollX > 0" + case ScrollDirection.RIGHT: + return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth" + + @classmethod + def from_string(cls, name: str) -> "ScrollDirection": + return cls[name.upper().strip()] diff --git a/lavague-sdk/lavague/sdk/base_driver/javascript.py b/lavague-sdk/lavague/sdk/base_driver/javascript.py new file mode 100644 index 00000000..a009300f --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/javascript.py @@ -0,0 +1,319 @@ +def js_wrap_function_call(fn: str): + return "(function(){" + fn + "})()" + + +JS_SETUP_GET_EVENTS = """ +(function() { + if (window && !window.getEventListeners) { + const targetProto = EventTarget.prototype; + targetProto._addEventListener = Element.prototype.addEventListener; + targetProto.addEventListener = function(a,b,c) { + this._addEventListener(a,b,c); + if(!this.eventListenerList) this.eventListenerList = {}; + if(!this.eventListenerList[a]) this.eventListenerList[a] = []; + this.eventListenerList[a].push(b); + }; + targetProto._removeEventListener = Element.prototype.removeEventListener; + targetProto.removeEventListener = function(a, b, c) { + this._removeEventListener(a, b, c); + if(this.eventListenerList && this.eventListenerList[a]) { + const index = this.eventListenerList[a].indexOf(b); + if (index > -1) { + this.eventListenerList[a].splice(index, 1); + if (!this.eventListenerList[a].length) { + delete this.eventListenerList[a]; + } + } + } + }; + window.getEventListeners = function(e) { + return (e && e.eventListenerList) || []; + } + } +})();""" + +JS_GET_INTERACTIVES = """ +const windowHeight = (window.innerHeight || document.documentElement.clientHeight); +const windowWidth = (window.innerWidth || document.documentElement.clientWidth); + +return (function(inViewport, foregroundOnly, nonInteractives, filterTypes) { + function getInteractions(e) { + const tag = e.tagName.toLowerCase(); + if (!e.checkVisibility() || (e.hasAttribute('disabled') && !nonInteractives) || e.hasAttribute('readonly') + || (tag === 'input' && e.getAttribute('type') === 'hidden') || tag === 'body') { + return []; + } + const rect = e.getBoundingClientRect(); + if (rect.width + rect.height < 5) { + return []; + } + const style = getComputedStyle(e) || {}; + if (style.display === 'none' || style.visibility === 'hidden') { + return []; + } + const events = window && typeof window.getEventListeners === 'function' ? window.getEventListeners(e) : []; + const role = e.getAttribute('role'); + const clickableInputs = ['submit', 'checkbox', 'radio', 'color', 'file', 'image', 'reset']; + function hasEvent(n) { + return events[n]?.length || e.hasAttribute('on' + n); + } + let evts = []; + if (hasEvent('keydown') || hasEvent('keyup') || hasEvent('keypress') || hasEvent('keydown') || hasEvent('input') || e.isContentEditable + || ( + (tag === 'input' || tag === 'textarea' || role === 'searchbox' || role === 'input') + ) && !clickableInputs.includes(e.getAttribute('type')) + ) { + evts.push('TYPE'); + } + if (['a', 'button', 'select'].includes(tag) || ['button', 'checkbox', 'select'].includes(role) + || hasEvent('click') || hasEvent('mousedown') || hasEvent('mouseup') || hasEvent('dblclick') + || style.cursor === 'pointer' + || e.hasAttribute('aria-haspopup') + || (tag === 'input' && clickableInputs.includes(e.getAttribute('type'))) + || (tag === 'label' && document.getElementById(e.getAttribute('for'))) + ) { + evts.push('CLICK'); + } + if ( + (hasEvent('scroll') || hasEvent('wheel') || style.overflow === 'auto' || style.overflow === 'scroll' || style.overflowY === 'auto' || style.overflowY === 'scroll') + && (e.scrollHeight > e.clientHeight || e.scrollWidth > e.clientWidth)) { + evts.push('SCROLL'); + } + if (filterTypes && filterTypes.length) { + evts = evts.filter(t => filterTypes.includes(t)); + } + if (nonInteractives && evts.length === 0) { + evts.push('NONE'); + } + + if (inViewport) { + let rect = e.getBoundingClientRect(); + let iframe = e.ownerDocument.defaultView.frameElement; + while (iframe) { + const iframeRect = iframe.getBoundingClientRect(); + rect = { + top: rect.top + iframeRect.top, + left: rect.left + iframeRect.left, + bottom: rect.bottom + iframeRect.bottom, + right: rect.right + iframeRect.right, + width: rect.width, + height: rect.height + } + iframe = iframe.ownerDocument.defaultView.frameElement; + } + const elemCenter = { + x: Math.round(rect.left + rect.width / 2), + y: Math.round(rect.top + rect.height / 2) + }; + if (elemCenter.x < 0) return []; + if (elemCenter.x > windowWidth) return []; + if (elemCenter.y < 0) return []; + if (elemCenter.y > windowHeight) return []; + if (!foregroundOnly) return evts; // whenever to check for elements above + let pointContainer = document.elementFromPoint(elemCenter.x, elemCenter.y); + iframe = e.ownerDocument.defaultView.frameElement; + while (iframe) { + const iframeRect = iframe.getBoundingClientRect(); + pointContainer = iframe.contentDocument.elementFromPoint( + elemCenter.x - iframeRect.left, + elemCenter.y - iframeRect.top + ); + iframe = iframe.ownerDocument.defaultView.frameElement; + } + do { + if (pointContainer === e) return evts; + if (pointContainer == null) return evts; + } while (pointContainer = pointContainer.parentNode); + return []; + } + return evts; + } + + const results = {}; + function traverse(node, xpath) { + if (node.nodeType === Node.ELEMENT_NODE) { + const interactions = getInteractions(node); + if (interactions.length > 0) { + results[xpath] = interactions; + } + } + const countByTag = {}; + for (let child = node.firstChild; child; child = child.nextSibling) { + let tag = child.nodeName.toLowerCase(); + if (tag.includes(":")) continue; //namespace + let isLocal = ['svg'].includes(tag); + if (isLocal) { + tag = `*[local-name() = '${tag}']`; + } + countByTag[tag] = (countByTag[tag] || 0) + 1; + let childXpath = xpath + '/' + tag; + if (countByTag[tag] > 1) { + childXpath += '[' + countByTag[tag] + ']'; + } + if (tag === 'iframe') { + try { + traverse(child.contentWindow.document.body, childXpath + '/html/body'); + } catch (e) { + console.warn("iframe access blocked", child, e); + } + } else if (!isLocal) { + traverse(child, childXpath); + if (child.shadowRoot) { + traverse(child.shadowRoot, childXpath + '/'); + } + } + } + } + traverse(document.body, '/html/body'); + return results; +})(arguments?.[0], arguments?.[1], arguments?.[2], arguments?.[3]); +""" + +JS_WAIT_DOM_IDLE = """ +return new Promise(resolve => { + const timeout = arguments[0] || 10000; + const stabilityThreshold = arguments[1] || 100; + + let mutationObserver; + let timeoutId = null; + + const waitForIdle = () => { + if (timeoutId) clearTimeout(timeoutId); + timeoutId = setTimeout(() => resolve(true), stabilityThreshold); + }; + mutationObserver = new MutationObserver(waitForIdle); + mutationObserver.observe(document.body, { + childList: true, + attributes: true, + subtree: true, + }); + waitForIdle(); + + setTimeout(() => { + resolve(false); + mutationObserver.disconnect(); + mutationObserver = null; + if (timeoutId) { + clearTimeout(timeoutId); + timeoutId = null; + } + }, timeout); +}); +""" + +JS_GET_SCROLLABLE_PARENT = """ +let element = arguments[0]; +while (element) { + const style = window.getComputedStyle(element); + + // Check if the element is scrollable + if (style.overflow === 'auto' || style.overflow === 'scroll' || + style.overflowX === 'auto' || style.overflowX === 'scroll' || + style.overflowY === 'auto' || style.overflowY === 'scroll') { + + // Check if the element has a scrollable area + if (element.scrollHeight > element.clientHeight || + element.scrollWidth > element.clientWidth) { + return element; + } + } + element = element.parentElement; +} +return null; +""" + +JS_GET_SHADOW_ROOTS = """ +const results = {}; +function traverse(node, xpath) { + if (node.shadowRoot) { + results[xpath] = node.shadowRoot.getHTML(); + } + const countByTag = {}; + for (let child = node.firstChild; child; child = child.nextSibling) { + let tag = child.nodeName.toLowerCase(); + countByTag[tag] = (countByTag[tag] || 0) + 1; + let childXpath = xpath + '/' + tag; + if (countByTag[tag] > 1) { + childXpath += '[' + countByTag[tag] + ']'; + } + if (child.shadowRoot) { + traverse(child.shadowRoot, childXpath + '/'); + } + if (tag === 'iframe') { + try { + traverse(child.contentWindow.document.body, childXpath + '/html/body'); + } catch (e) { + console.warn("iframe access blocked", child, e); + } + } else { + traverse(child, childXpath); + } + } +} +traverse(document.body, '/html/body'); +return results; +""" + +ATTACH_MOVE_LISTENER = """ +if (!window._lavague_move_listener) { + window._lavague_move_listener = function() { + const bbs = document.querySelectorAll('.lavague-highlight'); + bbs.forEach(bb => { + const rect = bb._tracking.getBoundingClientRect(); + bb.style.top = rect.top + 'px'; + bb.style.left = rect.left + 'px'; + bb.style.width = rect.width + 'px'; + bb.style.height = rect.height + 'px'; + }); + }; + window.addEventListener('scroll', window._lavague_move_listener); + window.addEventListener('resize', window._lavague_move_listener); +} +""" + +REMOVE_HIGHLIGHT = """ +if (window._lavague_move_listener) { + window.removeEventListener('scroll', window._lavague_move_listener); + window.removeEventListener('resize', window._lavague_move_listener); + delete window._lavague_move_listener; +} +arguments[0].filter(a => a).forEach(a => a.style.cssText = a.dataset.originalStyle || ''); +document.querySelectorAll('.lavague-highlight').forEach(a => a.remove()); +""" + + +def get_highlighter_style(color: str = "red", label: bool = False): + set_style = f""" + const r = a.getBoundingClientRect(); + const bb = document.createElement('div'); + const s = window.getComputedStyle(a); + bb.className = 'lavague-highlight'; + bb.style.position = 'fixed'; + bb.style.top = r.top + 'px'; + bb.style.left = r.left + 'px'; + bb.style.width = r.width + 'px'; + bb.style.height = r.height + 'px'; + bb.style.border = '3px solid {color}'; + bb.style.borderRadius = s.borderRadius; + bb.style['z-index'] = '2147483647'; + bb.style['pointer-events'] = 'none'; + bb._tracking = a; + document.body.appendChild(bb); + """ + + if label: + set_style += """ + const label = document.createElement('div'); + label.style.position = 'absolute'; + label.style.backgroundColor = 'red'; + label.style.color = 'white'; + label.style.padding = '0px 6px 2px 4px'; + label.style.top = '-12px'; + label.style.left = '-12px'; + label.style['font-size'] = '13pt'; + label.style['font-weight'] = 'bold'; + label.style['border-bottom-right-radius'] = '13px'; + label.textContent = i + 1; + bb.appendChild(label); + """ + return set_style diff --git a/lavague-sdk/lavague/sdk/base_driver/node.py b/lavague-sdk/lavague/sdk/base_driver/node.py new file mode 100644 index 00000000..0935ea9f --- /dev/null +++ b/lavague-sdk/lavague/sdk/base_driver/node.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from typing import Generic, Optional, TypeVar + +from PIL import Image + + +T = TypeVar("T") + + +class DOMNode(ABC, Generic[T]): + @property + @abstractmethod + def element(self) -> T: + pass + + @property + @abstractmethod + def text(self) -> str: + pass + + @property + @abstractmethod + def value(self) -> Optional[str]: + pass + + @property + @abstractmethod + def outer_html(self) -> str: + pass + + @property + @abstractmethod + def inner_html(self) -> str: + pass + + @abstractmethod + def take_screenshot(self) -> Image.Image: + pass + + @abstractmethod + def click(self): + pass + + @abstractmethod + def hover(self): + pass + + @abstractmethod + def set_value(self, value: str): + pass + + @abstractmethod + def enter_context(self): + pass + + @abstractmethod + def exit_context(self): + pass + + def __str__(self) -> str: + with self: + return self.outer_html + + def __enter__(self): + self.enter_context() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_context() diff --git a/lavague-sdk/lavague/sdk/client.py b/lavague-sdk/lavague/sdk/client.py index 64688abe..b5bdf143 100644 --- a/lavague-sdk/lavague/sdk/client.py +++ b/lavague-sdk/lavague/sdk/client.py @@ -1,12 +1,29 @@ -from lavague.sdk.trajectory.model import StepCompletion -from lavague.sdk.utilities.config import get_config, is_flag_true, LAVAGUE_API_BASE_URL -from lavague.sdk.action import ActionParser, DEFAULT_PARSER +from io import BytesIO +from typing import Any, Optional, Tuple + +import requests +from lavague.sdk.action import DEFAULT_PARSER, ActionParser from lavague.sdk.trajectory import Trajectory from lavague.sdk.trajectory.controller import TrajectoryController -from typing import Any, Optional +from lavague.sdk.trajectory.model import StepCompletion +from lavague.sdk.utilities.config import LAVAGUE_API_BASE_URL, get_config, is_flag_true from PIL import Image, ImageFile -from io import BytesIO -import requests +from pydantic import BaseModel + + +class RunRequest(BaseModel): + url: str + objective: str + step_by_step: Optional[bool] = False + cloud_driver: Optional[bool] = True + await_completion: Optional[bool] = False + is_public: Optional[bool] = False + viewport_size: Optional[Tuple[int, int]] = None + + +class RunUpdate(BaseModel): + objective: Optional[str] = None + is_public: Optional[bool] = False class LaVague(TrajectoryController): @@ -40,15 +57,23 @@ def request_api( json=json, headers=headers, ) - if response.status_code > 299: + if response.status_code >= 400: raise ApiException(response.text) return response.content - def run(self, url: str, objective: str, step_by_step=False) -> Trajectory: + def run(self, request: RunRequest) -> Trajectory: content = self.request_api( "/runs", "POST", - {"url": url, "objective": objective, "step_by_step": step_by_step}, + request.model_dump(), + ) + return Trajectory.from_data(content, self.parser, self) + + def update(self, run_id: str, request: RunUpdate) -> Trajectory: + content = self.request_api( + f"/runs/{run_id}", + "PATCH", + request.model_dump(), ) return Trajectory.from_data(content, self.parser, self) diff --git a/lavague-sdk/lavague/sdk/exceptions.py b/lavague-sdk/lavague/sdk/exceptions.py index 19fa7e25..078bab28 100644 --- a/lavague-sdk/lavague/sdk/exceptions.py +++ b/lavague-sdk/lavague/sdk/exceptions.py @@ -5,3 +5,13 @@ class DriverException(Exception): class CannotBackException(DriverException): def __init__(self, message="History root reached, cannot go back"): super().__init__(message) + + +class NoPageException(DriverException): + def __init__(self, message="No page loaded"): + super().__init__(message) + + +class ElementNotFoundException(DriverException): + def __init__(self, xpath: str): + super().__init__(f"Element not found: {xpath}") diff --git a/lavague-sdk/lavague/sdk/trajectory/base.py b/lavague-sdk/lavague/sdk/trajectory/base.py index 30fca6d7..752391bb 100644 --- a/lavague-sdk/lavague/sdk/trajectory/base.py +++ b/lavague-sdk/lavague/sdk/trajectory/base.py @@ -33,12 +33,15 @@ def run_to_completion(self): def stop_run(self): self._controller.stop(self.run_id) - self.status = RunStatus.CANCELLED + self.status = RunStatus.INTERRUPTED + self.error_msg = "Run interrupted by user" def iter_actions(self) -> Iterator[Action]: yield from self.actions while self.is_running: - yield self.next_action() + action = self.next_action() + if action is not None: + yield action @classmethod def from_data( diff --git a/lavague-sdk/lavague/sdk/trajectory/model.py b/lavague-sdk/lavague/sdk/trajectory/model.py index 554cae3f..4a79f143 100644 --- a/lavague-sdk/lavague/sdk/trajectory/model.py +++ b/lavague-sdk/lavague/sdk/trajectory/model.py @@ -21,6 +21,7 @@ class TrajectoryData(BaseModel): viewport_size: Tuple[int, int] status: RunStatus actions: List[SerializeAsAny[Action]] + error_msg: Optional[str] = None def write_to_file(self, file_path: str): json_model = self.model_dump_json(indent=2) diff --git a/lavague-sdk/lavague/sdk/utilities/format_utils.py b/lavague-sdk/lavague/sdk/utilities/format_utils.py index 0647d934..c7ba7031 100644 --- a/lavague-sdk/lavague/sdk/utilities/format_utils.py +++ b/lavague-sdk/lavague/sdk/utilities/format_utils.py @@ -1,5 +1,6 @@ import re + def quote_numeric_yaml_values(yaml_string: str) -> str: """Wrap numeric values in quotes in a YAML string. @@ -24,4 +25,4 @@ def replace_value(match): # Replace values that are numeric modified_yaml = re.sub(pattern, replace_value, yaml_string) - return modified_yaml \ No newline at end of file + return modified_yaml diff --git a/lavague-sdk/lavague/sdk/utilities/version_checker.py b/lavague-sdk/lavague/sdk/utilities/version_checker.py index ded0ee2b..8dfd7de1 100644 --- a/lavague-sdk/lavague/sdk/utilities/version_checker.py +++ b/lavague-sdk/lavague/sdk/utilities/version_checker.py @@ -41,6 +41,8 @@ def check_latest_version(): url = "https://pypi.org/pypi/lavague-sdk/json" response = requests.get(url) data = response.json() + if data.get("message") == "Not Found": + return latest_version = data["info"]["version"] if compare_versions(package_version, latest_version) < 0: warnings.warn(