Skip to content

Commit

Permalink
Modified scrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
dhuynh95 committed Oct 2, 2024
1 parent 5599da2 commit 5113f98
Showing 1 changed file with 64 additions and 45 deletions.
109 changes: 64 additions & 45 deletions lavague-sdk/lavague/sdk/base_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,52 @@ class InteractionType(Enum):

r_get_xpaths_from_html = r'xpath=["\'](.*?)["\']'

class ScrollDirection(Enum):
"""Enum for the different scroll directions. Value is (x, y, dimension_index)"""

LEFT = (-1, 0, 0)
RIGHT = (1, 0, 0)
UP = (0, -1, 1)
DOWN = (0, 1, 1)

def get_scroll_xy(
self, dimension: List[float], scroll_factor: float = 0.75
) -> Tuple[int, int]:
size = dimension[self.value[2]]
return (
round(self.value[0] * size * scroll_factor),
round(self.value[1] * size * scroll_factor),
)

def get_page_script(self, scroll_factor: float = 0.75) -> str:
return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);"

def get_script_element_is_scrollable(self) -> str:
match self:
case ScrollDirection.UP:
return "return arguments[0].scrollTop > 0"
case ScrollDirection.DOWN:
return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight"
case ScrollDirection.LEFT:
return "return arguments[0].scrollLeft > 0"
case ScrollDirection.RIGHT:
return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth"

def get_script_page_is_scrollable(self) -> str:
match self:
case ScrollDirection.UP:
return "return window.scrollY > 0"
case ScrollDirection.DOWN:
return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight"
case ScrollDirection.LEFT:
return "return window.scrollX > 0"
case ScrollDirection.RIGHT:
return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth"

@classmethod
def from_string(cls, name: str) -> "ScrollDirection":
return cls[name.upper().strip()]


class BaseDriver(ABC):
def __init__(self, url: Optional[str], init_function: Optional[Callable[[], Any]]):
Expand Down Expand Up @@ -223,6 +269,16 @@ def execute_script(self, js_code: str, *args) -> Any:
"""Exec js script in DOM"""
pass

@abstractmethod
def scroll(
self,
xpath_anchor: Optional[str],
direction: ScrollDirection,
scroll_factor=0.75,
):
pass

# TODO: Remove these methods as they are not used
@abstractmethod
def scroll_up(self):
pass
Expand Down Expand Up @@ -264,8 +320,16 @@ def get_obs(self) -> dict:
# If the last operation was to scan the whole page, we reset the flag
self.previously_scanned = False

# We add labels to the scrollable elements
i_scroll = self.get_possible_interactions(types=[InteractionType.SCROLL])
scrollables_xpaths = list(i_scroll.keys())

self.remove_highlight()
self.highlight_nodes(scrollables_xpaths, label=True)

# We take a screenshot and computes its hash to see if it already exists
self.save_screenshot(current_screenshot_folder)
self.remove_highlight()

url = self.get_url()
html = self.get_html()
Expand Down Expand Up @@ -425,51 +489,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pass


class ScrollDirection(Enum):
"""Enum for the different scroll directions. Value is (x, y, dimension_index)"""

LEFT = (-1, 0, 0)
RIGHT = (1, 0, 0)
UP = (0, -1, 1)
DOWN = (0, 1, 1)

def get_scroll_xy(
self, dimension: List[float], scroll_factor: float = 0.75
) -> Tuple[int, int]:
size = dimension[self.value[2]]
return (
round(self.value[0] * size * scroll_factor),
round(self.value[1] * size * scroll_factor),
)

def get_page_script(self, scroll_factor: float = 0.75) -> str:
return f"window.scrollBy({self.value[0] * scroll_factor} * window.innerWidth, {self.value[1] * scroll_factor} * window.innerHeight);"

def get_script_element_is_scrollable(self) -> str:
match self:
case ScrollDirection.UP:
return "return arguments[0].scrollTop > 0"
case ScrollDirection.DOWN:
return "return arguments[0].scrollTop + arguments[0].clientHeight + 1 < arguments[0].scrollHeight"
case ScrollDirection.LEFT:
return "return arguments[0].scrollLeft > 0"
case ScrollDirection.RIGHT:
return "return arguments[0].scrollLeft + arguments[0].clientWidth + 1 < arguments[0].scrollWidth"

def get_script_page_is_scrollable(self) -> str:
match self:
case ScrollDirection.UP:
return "return window.scrollY > 0"
case ScrollDirection.DOWN:
return "return window.innerHeight + window.scrollY + 1 < document.body.scrollHeight"
case ScrollDirection.LEFT:
return "return window.scrollX > 0"
case ScrollDirection.RIGHT:
return "return window.innerWidth + window.scrollX + 1 < document.body.scrollWidth"

@classmethod
def from_string(cls, name: str) -> "ScrollDirection":
return cls[name.upper().strip()]


def js_wrap_function_call(fn: str):
Expand Down

0 comments on commit 5113f98

Please sign in to comment.