diff --git a/RELEASE.md b/RELEASE.md index bf9e4cff8b..3261588f52 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -17,6 +17,7 @@ Please follow the established format: - Improve `kedro viz build` usage documentation (#2126) - Fix unserializable parameters value (#2122) +- Replace `watchgod` library with `watchfiles` and improve autoreload file watching filter (#2134) - Display full dataset type with library prefix in metadata panel (#2136) - Enable SQLite WAL mode for Azure ML to fix database locking issues (#2131) - Replace `flake8`, `isort`, `pylint` and `black` by `ruff` (#2149) diff --git a/package/features/steps/lower_requirements.txt b/package/features/steps/lower_requirements.txt index 1a149e7ad9..ab38585acf 100644 --- a/package/features/steps/lower_requirements.txt +++ b/package/features/steps/lower_requirements.txt @@ -3,7 +3,7 @@ fastapi==0.100.0 fsspec==2021.4 aiofiles==22.1.0 uvicorn[standard]==0.22.0 -watchgod==0.8.2 +watchfiles==0.24.0 plotly==4.8 packaging==23.0 pandas==1.3; python_version < '3.10' @@ -16,3 +16,4 @@ secure==0.3.0 # numpy 2.0 breaks with old versions of pandas and this # could be removed when the lowest version supported is updated numpy==1.26.4 +pathspec==0.12.1
diff --git a/package/kedro_viz/autoreload_file_filter.py b/package/kedro_viz/autoreload_file_filter.py
new file mode 100644
index 0000000000..f8b13c6237
--- /dev/null
+++ b/package/kedro_viz/autoreload_file_filter.py
@@ -0,0 +1,95 @@
+"""
+This module provides a custom file filter for autoreloading that filters out files based on allowed
+file extensions and patterns specified in a .gitignore file.
+"""
+
+import logging
+from pathlib import Path
+from typing import Optional, Set
+
+from pathspec import GitIgnoreSpec
+from watchfiles import Change, DefaultFilter
+
+logger = logging.getLogger(__name__)
+
+
+class AutoreloadFileFilter(DefaultFilter):
+    """
+    Custom file filter for autoreloading that extends DefaultFilter.
+    Filters out files based on allowed file extensions and patterns specified in a .gitignore file.
+
+    A change is accepted only if it passes ``DefaultFilter``, lies under the base path,
+    is not matched by the base path's ``.gitignore`` and has an allowed extension.
+    """
+
+    allowed_extensions: Set[str] = {".py", ".yml", ".yaml", ".json"}
+
+    def __init__(self, base_path: Optional[Path] = None):
+        """
+        Initialize the AutoreloadFileFilter.
+
+        Args:
+            base_path (Optional[Path]): The directory that changed paths are resolved
+                against and whose ``.gitignore`` is loaded. Defaults to the current
+                working directory. Changes outside this directory are always rejected,
+                so it should be the directory being watched.
+        """
+        self.cwd = base_path or Path.cwd()
+
+        # Call the superclass constructor
+        super().__init__()
+
+        # Load .gitignore patterns. An absent or unreadable .gitignore only disables
+        # the ignore rules; it must not prevent the watcher from starting.
+        self.gitignore_spec: Optional[GitIgnoreSpec] = None
+        gitignore_path = self.cwd / ".gitignore"
+        try:
+            with open(gitignore_path, "r", encoding="utf-8") as gitignore_file:
+                ignore_patterns = gitignore_file.read().splitlines()
+        except OSError as exc:
+            logger.debug("Could not read %s: %s", gitignore_path, exc)
+        else:
+            self.gitignore_spec = GitIgnoreSpec.from_lines(ignore_patterns)
+
+    def __call__(self, change: Change, path: str) -> bool:
+        """
+        Determine whether a file change should be processed.
+
+        Args:
+            change (Change): The type of change detected.
+            path (str): The path to the file that changed.
+
+        Returns:
+            bool: True if the file should be processed, False otherwise.
+        """
+        if not super().__call__(change, path):
+            logger.debug("Filtered out by DefaultFilter: %s", path)
+            return False
+
+        path_obj = Path(path)
+
+        # Paths outside the base path cannot be matched against its .gitignore
+        try:
+            relative_path = path_obj.resolve().relative_to(self.cwd.resolve())
+        except ValueError:
+            logger.debug("Path not relative to CWD: %s", path)
+            return False
+
+        # Exclude files matching .gitignore patterns
+        try:
+            if self.gitignore_spec and self.gitignore_spec.match_file(
+                str(relative_path)
+            ):
+                logger.debug("Filtered out by .gitignore: %s", relative_path)
+                return False
+        except Exception as exc:  # noqa: BLE001
+            # Best effort: a matching failure means "not ignored"; the extension
+            # check below still applies.
+            logger.debug("Exception during .gitignore matching: %s", exc)
+
+        # Include only files with allowed extensions
+        if path_obj.suffix in self.allowed_extensions:
+            logger.debug("Allowed file: %s", path)
+            return True
+        logger.debug("Filtered out by allowed_extensions: %s", path_obj.suffix)
+        return False
diff --git a/package/kedro_viz/launchers/cli/run.py b/package/kedro_viz/launchers/cli/run.py index e7dd08b408..b2e74a48be 100644 --- a/package/kedro_viz/launchers/cli/run.py +++ b/package/kedro_viz/launchers/cli/run.py @@ -7,6 +7,7 @@ from kedro.framework.cli.project import PARAMS_ARG_HELP from kedro.framework.cli.utils import _split_params +from kedro_viz.autoreload_file_filter import AutoreloadFileFilter from kedro_viz.constants import DEFAULT_HOST, DEFAULT_PORT from kedro_viz.launchers.cli.main import viz @@ -162,21 +163,25 @@ def run( "extra_params": params, "is_lite": lite, } + + process_context = multiprocessing.get_context("spawn") if autoreload: - from watchgod import RegExpWatcher, run_process + from watchfiles import run_process + run_process_args = [str(kedro_project_path)] run_process_kwargs = { - "path": kedro_project_path, "target": run_server, "kwargs": run_server_kwargs, - "watcher_cls": RegExpWatcher, - "watcher_kwargs": {"re_files": r"^.*(\.yml|\.yaml|\.py|\.json)$"}, + "watch_filter": AutoreloadFileFilter(), } -
viz_process = multiprocessing.Process( - target=run_process, daemon=False, kwargs={**run_process_kwargs} + viz_process = process_context.Process( + target=run_process, + daemon=False, + args=run_process_args, + kwargs={**run_process_kwargs}, ) else: - viz_process = multiprocessing.Process( + viz_process = process_context.Process( target=run_server, daemon=False, kwargs={**run_server_kwargs} ) diff --git a/package/kedro_viz/launchers/jupyter.py b/package/kedro_viz/launchers/jupyter.py index 22af9fb99a..cd39610ab3 100644 --- a/package/kedro_viz/launchers/jupyter.py +++ b/package/kedro_viz/launchers/jupyter.py @@ -14,8 +14,9 @@ import IPython from IPython.display import HTML, display from kedro.framework.project import PACKAGE_NAME -from watchgod import RegExpWatcher, run_process +from watchfiles import run_process +from kedro_viz.autoreload_file_filter import AutoreloadFileFilter from kedro_viz.launchers.utils import _check_viz_up, _wait_for from kedro_viz.server import DEFAULT_HOST, DEFAULT_PORT, run_server @@ -146,15 +147,17 @@ def run_viz(args: str = "", local_ns: Dict[str, Any] = None) -> None: } process_context = multiprocessing.get_context("spawn") if autoreload: + run_process_args = [str(project_path)] run_process_kwargs = { - "path": project_path, "target": run_server, "kwargs": run_server_kwargs, - "watcher_cls": RegExpWatcher, - "watcher_kwargs": {"re_files": r"^.*(\.yml|\.yaml|\.py|\.json)$"}, + "watch_filter": AutoreloadFileFilter(), } viz_process = process_context.Process( - target=run_process, daemon=False, kwargs={**run_process_kwargs} + target=run_process, + daemon=False, + args=run_process_args, + kwargs={**run_process_kwargs}, ) else: viz_process = process_context.Process( diff --git a/package/kedro_viz/server.py b/package/kedro_viz/server.py index 251bb32b6b..db95289b6d 100644 --- a/package/kedro_viz/server.py +++ b/package/kedro_viz/server.py @@ -8,6 +8,7 @@ from kedro.io import DataCatalog from kedro.pipeline import Pipeline +from 
kedro_viz.autoreload_file_filter import AutoreloadFileFilter from kedro_viz.constants import DEFAULT_HOST, DEFAULT_PORT from kedro_viz.data_access import DataAccessManager, data_access_manager from kedro_viz.database import make_db_session_factory @@ -143,7 +144,7 @@ def run_server( import argparse import multiprocessing - from watchgod import RegExpWatcher, run_process + from watchfiles import run_process parser = argparse.ArgumentParser(description="Launch a development viz server") parser.add_argument("project_path", help="Path to a Kedro project") @@ -157,20 +158,24 @@ def run_server( project_path = (Path.cwd() / args.project_path).absolute() + run_process_args = [str(project_path)] run_process_kwargs = { - "path": project_path, "target": run_server, "kwargs": { "host": args.host, "port": args.port, "project_path": str(project_path), }, - "watcher_cls": RegExpWatcher, - "watcher_kwargs": {"re_files": r"^.*(\.yml|\.yaml|\.py|\.json)$"}, + "watch_filter": AutoreloadFileFilter(), } - viz_process = multiprocessing.Process( - target=run_process, daemon=False, kwargs={**run_process_kwargs} + process_context = multiprocessing.get_context("spawn") + + viz_process = process_context.Process( + target=run_process, + daemon=False, + args=run_process_args, + kwargs={**run_process_kwargs}, ) display_cli_message("Starting Kedro Viz ...", "green") diff --git a/package/requirements.txt b/package/requirements.txt index caf3fa63ea..16c7890f5d 100644 --- a/package/requirements.txt +++ b/package/requirements.txt @@ -15,4 +15,5 @@ secure>=0.3.0 sqlalchemy>=1.4, <3 strawberry-graphql>=0.192.0, <1.0 uvicorn[standard]>=0.30.0, <1.0 -watchgod>=0.8.2, <1.0 +watchfiles>=0.24.0 +pathspec>=0.12.1 \ No newline at end of file diff --git a/package/test_requirements.txt b/package/test_requirements.txt index 3260a24806..c2ac8e7c78 100644 --- a/package/test_requirements.txt +++ b/package/test_requirements.txt @@ -19,6 +19,7 @@ sqlalchemy-stubs~=0.4 strawberry-graphql[cli]>=0.99.0, <1.0 
trufflehog~=2.2 httpx~=0.27.0 +pathspec>=0.12.1 # mypy types-aiofiles==0.1.3
diff --git a/package/tests/test_autoreload_file_filter.py b/package/tests/test_autoreload_file_filter.py
new file mode 100644
index 0000000000..d5c9fb2ff7
--- /dev/null
+++ b/package/tests/test_autoreload_file_filter.py
@@ -0,0 +1,136 @@
+import logging
+import shutil
+import tempfile
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from watchfiles import Change, DefaultFilter
+
+from kedro_viz.autoreload_file_filter import AutoreloadFileFilter
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def file_filter(tmp_path):
+    """
+    Fixture to create a temporary .gitignore file and initialize the AutoreloadFileFilter
+    with the test directory as the base path.
+    """
+    # Create a .gitignore file
+    gitignore_path = tmp_path / ".gitignore"
+    gitignore_path.write_text("ignored.py\n")
+
+    # Initialize the filter with the test directory as base_path
+    return AutoreloadFileFilter(base_path=tmp_path)
+
+
+def test_no_gitignore(tmp_path):
+    """
+    Test that a file passes the filter when the .gitignore file is missing.
+    """
+    gitignored_file = tmp_path / "ignored.py"
+    gitignored_file.touch()
+
+    # tmp_path is created empty for each test, so no .gitignore exists here
+    file_filter = AutoreloadFileFilter(base_path=tmp_path)
+
+    result = file_filter(Change.modified, str(gitignored_file))
+    assert result, "File should pass the filter when .gitignore is missing"
+
+
+def test_gitignore_exception(file_filter, tmp_path):
+    """
+    Test that a file passes the filter if an exception occurs during .gitignore matching.
+    """
+    allowed_file = tmp_path / "test.py"
+    allowed_file.touch()
+
+    with patch(
+        "pathspec.PathSpec.match_file", side_effect=Exception("Mocked exception")
+    ):
+        result = file_filter(Change.modified, str(allowed_file))
+        assert result, "Filter should pass the file if .gitignore matching fails"
+
+
+def test_allowed_file(file_filter, tmp_path):
+    """
+    Test that a file with an allowed extension passes the filter.
+    """
+    allowed_file = tmp_path / "test.py"
+    allowed_file.touch()
+
+    result = file_filter(Change.modified, str(allowed_file))
+    assert result, "Allowed file should pass the filter"
+
+
+def test_disallowed_file(file_filter, tmp_path):
+    """
+    Test that a file with a disallowed extension does not pass the filter.
+    """
+    disallowed_file = tmp_path / "test.txt"
+    disallowed_file.touch()
+
+    result = file_filter(Change.modified, str(disallowed_file))
+    assert not result, "Disallowed file should not pass the filter"
+
+
+def test_gitignored_file(file_filter, tmp_path):
+    """
+    Test that a file listed in the .gitignore file does not pass the filter.
+    """
+    gitignored_file = tmp_path / "ignored.py"
+    gitignored_file.touch()
+
+    result = file_filter(Change.modified, str(gitignored_file))
+    assert not result, "Gitignored file should not pass the filter"
+
+
+def test_non_relative_path(file_filter, tmp_path_factory):
+    """
+    Test that a file outside the filter's base path does not pass the filter.
+    """
+    # Use a separate pytest-managed directory rather than the parent of the
+    # process CWD: nothing is written outside the temp area and cleanup is automatic.
+    outside_file = tmp_path_factory.mktemp("outside") / "outside.py"
+    outside_file.touch()
+
+    result = file_filter(Change.modified, str(outside_file))
+    assert not result, "File outside the CWD should not pass the filter"
+
+
+def test_no_allowed_extension(file_filter, tmp_path):
+    """
+    Test that a file without an allowed extension does not pass the filter.
+    """
+    no_extension_file = tmp_path / "no_extension"
+    no_extension_file.touch()
+
+    result = file_filter(Change.modified, str(no_extension_file))
+    assert not result, "File without allowed extension should not pass the filter"
+
+
+def test_directory_path(file_filter, tmp_path):
+    """
+    Test that a directory does not pass the filter.
+    """
+    directory_path = tmp_path / "some_directory"
+    directory_path.mkdir()
+
+    result = file_filter(Change.modified, str(directory_path))
+    assert not result, "Directories should not pass the filter"
+
+
+def test_filtered_out_by_default_filter(file_filter, tmp_path, mocker):
+    """
+    Test that a file is filtered out by the DefaultFilter.
+    """
+    filtered_file = tmp_path / "filtered.py"
+    filtered_file.touch()
+
+    # Mock the super().__call__ method to return False
+    mocker.patch.object(DefaultFilter, "__call__", return_value=False)
+
+    result = file_filter(Change.modified, str(filtered_file))
+    assert not result, "File should be filtered out by DefaultFilter"
diff --git a/package/tests/test_launchers/test_cli/test_run.py b/package/tests/test_launchers/test_cli/test_run.py index b2d5c59b39..86adae92f6 100644 --- a/package/tests/test_launchers/test_cli/test_run.py +++ b/package/tests/test_launchers/test_cli/test_run.py @@ -4,9 +4,10 @@ import requests from click.testing import CliRunner from packaging.version import parse -from watchgod import RegExpWatcher, run_process +from watchfiles import run_process from kedro_viz import __version__ +from kedro_viz.autoreload_file_filter import AutoreloadFileFilter from kedro_viz.launchers.cli import main from kedro_viz.launchers.cli.run import _VIZ_PROCESSES from kedro_viz.launchers.utils import _PYPROJECT @@ -205,7 +206,10 @@ def test_kedro_viz_command_run_server( patched_check_viz_up, patched_start_browser, ): - process_init = mocker.patch("multiprocessing.Process") + mock_process_context = mocker.patch("multiprocessing.get_context") + mock_context_instance = mocker.Mock() +
mock_process_context.return_value = mock_context_instance + mock_process = mocker.patch.object(mock_context_instance, "Process") runner = CliRunner() # Reduce the timeout argument from 600 to 1 to make test run faster. @@ -222,7 +226,7 @@ def test_kedro_viz_command_run_server( with runner.isolated_filesystem(): runner.invoke(main.viz_cli, command_options) - process_init.assert_called_once_with( + mock_process.assert_called_once_with( target=run_server, daemon=False, kwargs={**run_server_args} ) @@ -340,9 +344,15 @@ def test_kedro_viz_command_should_not_log_if_pypi_is_down( mock_click_echo.assert_has_calls(mock_click_echo_calls) def test_kedro_viz_command_with_autoreload( - self, mocker, mock_project_path, patched_check_viz_up, patched_start_browser + self, mocker, tmp_path, patched_check_viz_up, patched_start_browser ): - process_init = mocker.patch("multiprocessing.Process") + mock_process_context = mocker.patch("multiprocessing.get_context") + mock_context_instance = mocker.Mock() + mock_process_context.return_value = mock_context_instance + mock_process = mocker.patch.object(mock_context_instance, "Process") + mock_tmp_path = tmp_path / "tmp" + mock_tmp_path.mkdir() + mock_path = mock_tmp_path / "project_path" # Reduce the timeout argument from 600 to 1 to make test run faster. 
mocker.patch( @@ -351,14 +361,14 @@ def test_kedro_viz_command_with_autoreload( # Mock finding kedro project mocker.patch( "kedro_viz.launchers.utils._find_kedro_project", - return_value=mock_project_path, + return_value=mock_path, ) runner = CliRunner() with runner.isolated_filesystem(): runner.invoke(main.viz_cli, ["viz", "run", "--autoreload"]) + run_process_args = [str(mock_path)] run_process_kwargs = { - "path": mock_project_path, "target": run_server, "kwargs": { "host": "127.0.0.1", @@ -367,18 +377,20 @@ def test_kedro_viz_command_with_autoreload( "save_file": None, "pipeline_name": None, "env": None, + "project_path": mock_path, "autoreload": True, - "project_path": mock_project_path, "include_hooks": False, "package_name": None, "extra_params": {}, "is_lite": False, }, - "watcher_cls": RegExpWatcher, - "watcher_kwargs": {"re_files": "^.*(\\.yml|\\.yaml|\\.py|\\.json)$"}, + "watch_filter": mocker.ANY, } - process_init.assert_called_once_with( - target=run_process, daemon=False, kwargs={**run_process_kwargs} + mock_process.assert_called_once_with( + target=run_process, + daemon=False, + args=run_process_args, + kwargs={**run_process_kwargs}, ) assert run_process_kwargs["kwargs"]["port"] in _VIZ_PROCESSES diff --git a/package/tests/test_launchers/test_jupyter.py b/package/tests/test_launchers/test_jupyter.py index dd489778ca..485e7ff890 100644 --- a/package/tests/test_launchers/test_jupyter.py +++ b/package/tests/test_launchers/test_jupyter.py @@ -140,6 +140,7 @@ def test_run_viz_with_autoreload(self, mocker, patched_check_viz_up): mock_process.assert_called_once_with( target=mocker.ANY, daemon=False, # No daemon for autoreload + args=mocker.ANY, kwargs=mocker.ANY, )