From 076356a3f576ebcbfcf75cc8e03dfe59fdb7a890 Mon Sep 17 00:00:00 2001 From: Nils Uhrberg Date: Mon, 16 Sep 2024 14:45:18 +0200 Subject: [PATCH 1/2] feat: Add the ability to use Github repos that contain symlinks. Since symlinks via the API have no content, the download of the repo did not work. --- .../github_code_repository.py | 69 ++++++++++++++++++- tests/unit/integrations/__init__.py | 0 tests/unit/integrations/github/__init__.py | 0 .../github/test_github_code_repository.py | 47 +++++++++++++ 4 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 tests/unit/integrations/__init__.py create mode 100644 tests/unit/integrations/github/__init__.py create mode 100644 tests/unit/integrations/github/test_github_code_repository.py diff --git a/src/zenml/integrations/github/code_repositories/github_code_repository.py b/src/zenml/integrations/github/code_repositories/github_code_repository.py index 13be68e6a46..4adc4d5edaa 100644 --- a/src/zenml/integrations/github/code_repositories/github_code_repository.py +++ b/src/zenml/integrations/github/code_repositories/github_code_repository.py @@ -15,7 +15,7 @@ import os import re -from typing import List, Optional +from typing import List, Optional, Tuple import requests from github import Github, GithubException @@ -151,7 +151,7 @@ def download_files( raise RuntimeError("Invalid repository subdirectory.") os.makedirs(directory, exist_ok=True) - + tmp_symlinks: List[Tuple[str, str]] = [] for content in contents: local_path = os.path.join(directory, content.name) if content.type == "dir": @@ -160,12 +160,25 @@ def download_files( directory=local_path, repo_sub_directory=content.path, ) + elif content.type == "symlink": + symlink_content = self.github_repo.get_contents( + content.path, ref=commit + ) + symlink_target = symlink_content.raw_data["target"] + tmp_symlinks.append((local_path, symlink_target)) + # As it cannot be assumed at this point that the targets of the + # symlink already exist, the symlinks are 
first collected here and processed later. else: try: with open(local_path, "wb") as f: f.write(content.decoded_content) except (GithubException, IOError) as e: logger.error("Error processing %s: %s", content.path, e) + for symlink in tmp_symlinks: + symlink_source, symlink_target = symlink + create_symlink_in_local_repo_copy( + symlink_source=symlink_source, symlink_target=symlink_target + ) def get_local_context(self, path: str) -> Optional[LocalRepositoryContext]: """Gets the local repository context. @@ -202,3 +215,55 @@ def check_remote_url(self, url: str) -> bool: return True return False + + +def create_symlink_in_local_repo_copy( + symlink_source: str, symlink_target: str +) -> None: + """This function attempts to create a symbolic link at `local_path` that points to `symlink_target`. + + If a file or directory already exists at `local_path`, it will + be removed before the symbolic link is created. + + Args: + symlink_source: The path where the symbolic link should be created. + symlink_target: The path that the symbolic link should point to. + + Raises: + FileNotFoundError: Informs that the target directory specified by + `symlink_target` does not exist. + PermissionError: Informs that there are insufficient permissions to + create the symbolic link. + NotImplementedError: Informs that symbolic links are not supported on + the current operating system. + OSError: Any other OS-related errors that occur. + """ + try: + if os.path.exists(symlink_source): + if os.path.isdir(symlink_source): + os.rmdir(symlink_source) + else: + os.remove( + symlink_source, + ) + os.symlink(symlink_target, symlink_source) + except FileNotFoundError: + logger.warning( + "The target directory of the symbolic link '%s' does not exist. " + "The creation of the symbolic link is skipped.", + symlink_target, + ) + except PermissionError: + logger.warning( + "You do not have the necessary permissions to create the symbolic link. 
" + "The creation of the symbolic link '%s' is skipped.", + symlink_source, + ) + except NotImplementedError: + logger.warning( + "Symbolic links are not supported on this operating system. " + "The creation of the symbolic link '%s' is skipped.", + symlink_source, + ) + except OSError as e: + logger.warning("An OS error occurred: %s", e) diff --git a/tests/unit/integrations/__init__.py b/tests/unit/integrations/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/integrations/github/__init__.py b/tests/unit/integrations/github/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/integrations/github/test_github_code_repository.py b/tests/unit/integrations/github/test_github_code_repository.py new file mode 100644 index 00000000000..130a24df3f8 --- /dev/null +++ b/tests/unit/integrations/github/test_github_code_repository.py @@ -0,0 +1,47 @@ +import logging +import os +from tempfile import TemporaryDirectory + +import pytest + +from zenml.integrations.github.code_repositories.github_code_repository import ( + create_symlink_in_local_repo_copy, +) + + +@pytest.fixture() +def tmp_dir() -> str: + with TemporaryDirectory() as tmp: + yield tmp + + +@pytest.mark.parametrize("initial_state", ["file", "directory"]) +def test_create_symlink_in_local_repo_copy(tmp_dir, initial_state: str): + local_path = os.path.join(tmp_dir, "test_symlink") + symlink_target = os.path.join(tmp_dir, "target_folder") + + if initial_state == "file": + with open(local_path) as file: + file.write("content") + elif initial_state == "directory": + os.mkdir(local_path) + os.mkdir(symlink_target) + + create_symlink_in_local_repo_copy( + symlink_source=local_path, symlink_target=symlink_target + ) + assert os.path.islink(local_path) + assert os.readlink(local_path) == str(symlink_target) + + +def test_create_symlink_in_local_repo_copy_target_nonexistent(tmp_dir, caplog): + local_path = os.path.join(tmp_dir, "test_symlink") + 
symlink_target = os.path.join(tmp_dir, "target_folder") + + with caplog.at_level(logging.WARNING): + create_symlink_in_local_repo_copy( + symlink_source=local_path, symlink_target=symlink_target + ) + + assert not os.path.exists(local_path) + assert "The target directory of the symbolic link" in caplog.text From 88b7c60c8ad4b423611696ca41e82455cce1b725 Mon Sep 17 00:00:00 2001 From: Nils Uhrberg Date: Wed, 2 Oct 2024 16:16:51 +0200 Subject: [PATCH 2/2] refactor: The symlink is no longer deleted and recreated if the symlink already exists --- .../github_code_repository.py | 38 +++++++++---------- .../github/test_github_code_repository.py | 24 ++++++------ 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/src/zenml/integrations/github/code_repositories/github_code_repository.py b/src/zenml/integrations/github/code_repositories/github_code_repository.py index 4adc4d5edaa..6627e08ac92 100644 --- a/src/zenml/integrations/github/code_repositories/github_code_repository.py +++ b/src/zenml/integrations/github/code_repositories/github_code_repository.py @@ -164,8 +164,8 @@ def download_files( symlink_content = self.github_repo.get_contents( content.path, ref=commit ) - symlink_target = symlink_content.raw_data["target"] - tmp_symlinks.append((local_path, symlink_target)) + symlink_src = symlink_content.raw_data["target"] + tmp_symlinks.append((local_path, symlink_src)) # As it cannot be assumed at this point that the targets of the # symlink already exist, the symlinks are first collected here and processed later. 
else: @@ -175,9 +175,9 @@ def download_files( except (GithubException, IOError) as e: logger.error("Error processing %s: %s", content.path, e) for symlink in tmp_symlinks: - symlink_source, symlink_target = symlink + symlink_dst, symlink_src = symlink create_symlink_in_local_repo_copy( - symlink_source=symlink_source, symlink_target=symlink_target + symlink_dst=symlink_dst, symlink_src=symlink_src ) def get_local_context(self, path: str) -> Optional[LocalRepositoryContext]: @@ -218,20 +218,20 @@ def check_remote_url(self, url: str) -> bool: def create_symlink_in_local_repo_copy( - symlink_source: str, symlink_target: str + symlink_dst: str, symlink_src: str ) -> None: - """This function attempts to create a symbolic link at `local_path` that points to `symlink_target`. + """This function attempts to create a symbolic link at `symlink_dst` that points to `symlink_src`. - If a file or directory already exists at `local_path`, it will + If a file or directory already exists at `symlink_dst`, it is - be removed before the symbolic link is created. + left in place and the creation of the symbolic link is skipped. Args: - symlink_source: The path where the symbolic link should be created. - symlink_target: The path that the symbolic link should point to. + symlink_dst: The path where the symbolic link should be created. + symlink_src: The path that the symbolic link should point to. Raises: FileNotFoundError: Informs that the target directory specified by - `symlink_target` does not exist. + `symlink_dst` does not exist. PermissionError: Informs that there are insufficient permissions to create the symbolic link. NotImplementedError: Informs that symbolic links are not supported on @@ -239,31 +239,27 @@ def create_symlink_in_local_repo_copy( OSError: Any other OS-related errors that occur. 
""" try: - if os.path.exists(symlink_source): - if os.path.isdir(symlink_source): - os.rmdir(symlink_source) - else: - os.remove( - symlink_source, - ) - os.symlink(symlink_target, symlink_source) + os.symlink(src=symlink_src, dst=symlink_dst) + + except FileExistsError as e: + logger.debug("The symbolic link already exists. %s",e) except FileNotFoundError: logger.warning( "The target directory of the symbolic link '%s' does not exist. " "The creation of the symbolic link is skipped.", - symlink_target, + symlink_src, ) except PermissionError: logger.warning( "You do not have the necessary permissions to create the symbolic link. " "The creation of the symbolic link '%s' is skipped.", - symlink_source, + symlink_dst, ) except NotImplementedError: logger.warning( "Symbolic links are not supported on this operating system. " "The creation of the symbolic link '%s' is skipped.", - symlink_source, + symlink_dst, ) except OSError as e: logger.warning("An OS error occurred: %s", e) diff --git a/tests/unit/integrations/github/test_github_code_repository.py b/tests/unit/integrations/github/test_github_code_repository.py index 130a24df3f8..387642e6182 100644 --- a/tests/unit/integrations/github/test_github_code_repository.py +++ b/tests/unit/integrations/github/test_github_code_repository.py @@ -17,31 +17,31 @@ def tmp_dir() -> str: @pytest.mark.parametrize("initial_state", ["file", "directory"]) def test_create_symlink_in_local_repo_copy(tmp_dir, initial_state: str): - local_path = os.path.join(tmp_dir, "test_symlink") - symlink_target = os.path.join(tmp_dir, "target_folder") + symlink_dst = os.path.join(tmp_dir, "test_symlink") + symlink_src = os.path.join(tmp_dir, "target_folder") if initial_state == "file": - with open(local_path) as file: + with open(symlink_dst) as file: file.write("content") elif initial_state == "directory": - os.mkdir(local_path) - os.mkdir(symlink_target) + os.mkdir(symlink_dst) + os.mkdir(symlink_src) create_symlink_in_local_repo_copy( - 
symlink_source=local_path, symlink_target=symlink_target + symlink_dst=symlink_dst, symlink_src=symlink_src ) - assert os.path.islink(local_path) - assert os.readlink(local_path) == str(symlink_target) + assert os.path.islink(symlink_dst) + assert os.readlink(symlink_dst) == str(symlink_src) def test_create_symlink_in_local_repo_copy_target_nonexistent(tmp_dir, caplog): - local_path = os.path.join(tmp_dir, "test_symlink") - symlink_target = os.path.join(tmp_dir, "target_folder") + symlink_dst = os.path.join(tmp_dir, "test_symlink") + symlink_src = os.path.join(tmp_dir, "target_folder") with caplog.at_level(logging.WARNING): create_symlink_in_local_repo_copy( - symlink_source=local_path, symlink_target=symlink_target + symlink_dst=symlink_dst, symlink_src=symlink_src ) - assert not os.path.exists(local_path) + assert not os.path.exists(symlink_dst) assert "The target directory of the symbolic link" in caplog.text