diff --git a/lavague-sdk/lavague/sdk/base_driver.py b/lavague-sdk/lavague/sdk/base_driver.py index f8073309..ec398851 100644 --- a/lavague-sdk/lavague/sdk/base_driver.py +++ b/lavague-sdk/lavague/sdk/base_driver.py @@ -20,6 +20,52 @@ class InteractionType(Enum): r_get_xpaths_from_html = r'xpath=["\'](.*?)["\']' +class ScrollDirection(Enum): + """Enum for the different scroll directions. Value is (x, y, dimension_index)""" + + LEFT = (-1, 0, 0) + RIGHT = (1, 0, 0) + UP = (0, -1, 1) + DOWN = (0, 1, 1) + + def get_scroll_xy( + self, dimension: List[float], scroll_factor: float = 0.75 + ) -> Tuple[int, int]: + size = dimension[self.value[2]] + return ( + round(self.value[0] * size * scroll_factor), + round(self.value[1] * size * scroll_factor), + ) + + def get_page_script(self, scroll_factor: float = 0.75) -> str: + return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);" + + def get_script_element_is_scrollable(self) -> str: + match self: + case ScrollDirection.UP: + return "return arguments[0].scrollTop > 0" + case ScrollDirection.DOWN: + return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight" + case ScrollDirection.LEFT: + return "return arguments[0].scrollLeft > 0" + case ScrollDirection.RIGHT: + return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth" + + def get_script_page_is_scrollable(self) -> str: + match self: + case ScrollDirection.UP: + return "return window.scrollY > 0" + case ScrollDirection.DOWN: + return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight" + case ScrollDirection.LEFT: + return "return window.scrollX > 0" + case ScrollDirection.RIGHT: + return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth" + + @classmethod + def from_string(cls, name: str) -> "ScrollDirection": + return cls[name.upper().strip()] + class BaseDriver(ABC): def __init__(self, url: Optional[str], init_function: Optional[Callable[[], Any]]): @@ -223,6 +269,16 @@ def execute_script(self, js_code: str, *args) -> Any: """Exec js script in DOM""" pass + @abstractmethod + def scroll( + self, + xpath_anchor: Optional[str], + direction: ScrollDirection, + scroll_factor=0.75, + ): + pass + + # TODO: Remove these methods as they are not used @abstractmethod def scroll_up(self): pass @@ -264,8 +320,16 @@ def get_obs(self) -> dict: # If the last operation was to scan the whole page, we reset the flag self.previously_scanned = False + # We add labels to the scrollable elements + i_scroll = self.get_possible_interactions(types=[InteractionType.SCROLL]) + scrollables_xpaths = list(i_scroll.keys()) + + self.remove_highlight() + self.highlight_nodes(scrollables_xpaths, label=True) + # We take a screenshot and computes its hash to see if it already exists self.save_screenshot(current_screenshot_folder) + self.remove_highlight() url = self.get_url() html = self.get_html() @@ -425,51 +489,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass -class ScrollDirection(Enum): - """Enum for the different scroll directions. Value is (x, y, dimension_index)""" - - LEFT = (-1, 0, 0) - RIGHT = (1, 0, 0) - UP = (0, -1, 1) - DOWN = (0, 1, 1) - - def get_scroll_xy( - self, dimension: List[float], scroll_factor: float = 0.75 - ) -> Tuple[int, int]: - size = dimension[self.value[2]] - return ( - round(self.value[0] * size * scroll_factor), - round(self.value[1] * size * scroll_factor), - ) - - def get_page_script(self, scroll_factor: float = 0.75) -> str: - return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);" - - def get_script_element_is_scrollable(self) -> str: - match self: - case ScrollDirection.UP: - return "return arguments[0].scrollTop > 0" - case ScrollDirection.DOWN: - return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight" - case ScrollDirection.LEFT: - return "return arguments[0].scrollLeft > 0" - case ScrollDirection.RIGHT: - return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth" - - def get_script_page_is_scrollable(self) -> str: - match self: - case ScrollDirection.UP: - return "return window.scrollY > 0" - case ScrollDirection.DOWN: - return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight" - case ScrollDirection.LEFT: - return "return window.scrollX > 0" - case ScrollDirection.RIGHT: - return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth" - - @classmethod - def from_string(cls, name: str) -> "ScrollDirection": - return cls[name.upper().strip()] def js_wrap_function_call(fn: str):