From 1a671bb32ec25353a94a96d187c0610f11045662 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 09:14:59 +0200 Subject: [PATCH 01/61] exclude test_data/samples/* from black, isort, mypy --- pyproject.toml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 65de09a..f84053b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,15 @@ [tool.black] +force-exclude = 'test_data/samples/*' line_length = 79 [tool.isort] -profile = "black" -multi_line_output = 3 +extend_skip = ["test_data/samples"] line_length = 79 +multi_line_output = 3 +profile = "black" + +[tool.mypy] +exclude = "test_data/samples/*" [[tool.mypy.overrides]] module = "pathspec" From dbce6248f27255b1e06454eef7b15e663fcd1419 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 09:15:29 +0200 Subject: [PATCH 02/61] remove black, isort, mypy config from tox.ini --- tox.ini | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tox.ini b/tox.ini index b90444e..0ba85cd 100644 --- a/tox.ini +++ b/tox.ini @@ -12,14 +12,14 @@ deps = black skip_install = True commands = - black --check src/ssort tests setup.py + black --check . [testenv:isort] deps = isort skip_install = True commands = - isort --check-only src/ssort tests setup.py + isort --check-only . [testenv:ssort] commands = @@ -48,4 +48,4 @@ deps = types-setuptools skip_install = True commands = - mypy src/ssort tests setup.py + mypy . From 056f6e23c3ecb70cfa4078bef53f5cf36674cfe1 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 09:16:36 +0200 Subject: [PATCH 03/61] remove black, isort, mypy config from github ci --- .github/workflows/ci.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0904394..02b9077 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -66,7 +66,7 @@ jobs: pip install black - name: Run black run: | - black --check --diff src/ tests/ setup.py + black --check --diff . isort: name: "ISort" @@ -83,7 +83,7 @@ jobs: pip install isort - name: Run isort run: | - isort --check-only --diff src/ tests/ setup.py + isort --check-only --diff . ssort: name: "SSort" @@ -155,4 +155,4 @@ jobs: pip install types-setuptools - name: Run mypy run: | - mypy src/ssort tests setup.py + mypy . From ac2840cd15210ad35c01366f7188ac8c8c579d03 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 09:21:54 +0200 Subject: [PATCH 04/61] remove testing of setup.py in tox.ini --- .github/workflows/ci.yaml | 6 +++--- tox.ini | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 02b9077..a18b3df 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -100,7 +100,7 @@ jobs: pip install -e . 
- name: Run ssort run: | - ssort --check --diff src/ tests/ setup.py + ssort --check --diff src/ tests/ pyflakes: name: "PyFlakes" @@ -117,7 +117,7 @@ jobs: pip install pyflakes - name: Run pyflakes run: | - pyflakes src/ tests/ setup.py + pyflakes src/ tests/ pylint: name: "PyLint" @@ -136,7 +136,7 @@ jobs: pip install pylint - name: Run pylint run: | - pylint -E src/ tests/ setup.py + pylint -E src/ tests/ mypy: name: "Mypy" diff --git a/tox.ini b/tox.ini index 0ba85cd..78ed059 100644 --- a/tox.ini +++ b/tox.ini @@ -23,14 +23,14 @@ commands = [testenv:ssort] commands = - ssort --check --diff src/ssort tests setup.py + ssort --check --diff src/ssort tests [testenv:pyflakes] deps = pyflakes skip_install = True commands = - pyflakes src/ssort tests setup.py + pyflakes src/ssort tests [testenv:pylint] deps = @@ -39,7 +39,7 @@ deps = extras= test commands = - pylint -E src/ssort tests setup.py + pylint -E src/ssort tests [testenv:mypy] deps = From 5b8776ce88928a8f0e5d9d141fbc1f13dc1d0cdb Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 09:22:15 +0200 Subject: [PATCH 05/61] remove setup.py --- setup.py | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 setup.py diff --git a/setup.py b/setup.py deleted file mode 100644 index 435df65..0000000 --- a/setup.py +++ /dev/null @@ -1,8 +0,0 @@ -import os -import sys - -import setuptools - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) # noqa - -setuptools.setup() From d50431a814b5ae84b78c5d60492419917e3be8e2 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 09:26:02 +0200 Subject: [PATCH 06/61] move config from setup.cfg to pyproject.toml --- pyproject.toml | 63 +++++++++++++++++++++++++++++++++++++++++++++++++- setup.cfg | 43 ---------------------------------- 2 files changed, 62 insertions(+), 44 deletions(-) delete mode 100644 setup.cfg diff --git a/pyproject.toml b/pyproject.toml index f84053b..5ab348e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,56 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = [ + "setuptools>=61.2" +] + +[project] +authors = [ + {email = "bwhmather@bwhmather.com", name = "Ben Mather"} +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Software Development :: Quality Assurance" +] +dependencies = [ + "pathspec >=0.9.0" +] +description = "The python statement sorter" +dynamic = [ + "version" +] +maintainers = [ + {email = "bwhmather@bwhmather.com", name = "Ben Mather"} +] +name = "ssort" +readme = "README.rst" +requires-python = ">=3.8" + +[project.license] +text = "MIT" + +[project.scripts] +ssort = "ssort._main:main" + +[project.urls] +Homepage = "https://github.com/bwhmather/ssort" + [tool.black] force-exclude = 'test_data/samples/*' line_length = 79 +[tool.distutils.bdist_wheel] +universal = 1 + [tool.isort] extend_skip = ["test_data/samples"] line_length = 79 @@ -12,5 +61,17 @@ profile = "black" exclude = "test_data/samples/*" [[tool.mypy.overrides]] -module = "pathspec" ignore_missing_imports = true +module = "pathspec" + +[tool.setuptools] +include-package-data = false +license-files = [ + "LICENSE" +] + 
+[tool.setuptools.dynamic.version] +attr = "ssort.__version__" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1fa6964..0000000 --- a/setup.cfg +++ /dev/null @@ -1,43 +0,0 @@ -[bdist_wheel] -universal = 1 - -[metadata] -name = ssort -version = attr: ssort.__version__ -license = MIT -license_file = LICENSE -description = The python statement sorter -long_description = file: README.rst -author = Ben Mather -author_email = bwhmather@bwhmather.com -maintainer = Ben Mather -maintainer_email = bwhmather@bwhmather.com -url = https://github.com/bwhmather/ssort -classifiers = - Development Status :: 5 - Production/Stable - Environment :: Console - Intended Audience :: Developers - License :: OSI Approved :: MIT License - Programming Language :: Python - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Topic :: Software Development :: Libraries :: Python Modules - Topic :: Software Development :: Quality Assurance - -[options] -package_dir= - =src -packages = find: - -install_requires = - pathspec >=0.9.0 -python_requires = >=3.8 - -[options.packages.find] -where = src - -[options.entry_points] -console_scripts = - ssort = ssort._main:main From 0c8cf6848303b31f2db243104b07360ad7f79a9e Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 09:41:25 +0200 Subject: [PATCH 07/61] remove mention of setup.py in readme --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index cf1d4b2..8d9e354 100644 --- a/README.rst +++ b/README.rst @@ -104,7 +104,7 @@ We recommend that you reformat using `isort `_ a .. code:: bash - $ ssort src/ tests/ setup.py; isort src/ tests/ setup.py; black src/ tests/ setup.py + $ ssort src/ tests/; isort src/ tests/ setup.py; black src/ tests/ setup.py You can also setup ssort to run automatically before commit by setting up `pre-commit `_, and registering ssort in your `.pre-commit-config.yaml`. 
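
For reference, a pre-commit registration along the lines the README describes could look roughly like the sketch below. The hook id `ssort` is assumed from the project's published pre-commit hook definition, and the `rev` value is a placeholder to be pinned to a released tag.

.. code:: yaml

    # Minimal .pre-commit-config.yaml sketch for running ssort before commit.
    # The hook id is assumed; pin rev to a real release tag before use.
    repos:
      - repo: https://github.com/bwhmather/ssort
        rev: "..."  # placeholder: pin to a released ssort tag
        hooks:
          - id: ssort
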
From 556e6748c188f9e9097b10235e9c44ae80282460 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 12:03:14 +0200 Subject: [PATCH 08/61] update tox.ini for PEP517 compatibility --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 78ed059..5c52c21 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,6 @@ [tox] envlist = py38,py39,py310,black,isort,ssort,pyflakes,pylint,mypy +isolated_build = true [testenv] deps = From 8c4190a2b8f24777a3710425cd95b5df097f1c5d Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 09:25:11 +0200 Subject: [PATCH 09/61] update release ci to not use setup.py --- .github/workflows/release.yaml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index f7d41a8..6a2bbf4 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -11,14 +11,17 @@ jobs: runs-on: ubuntu-22.04 if: success() && startsWith(github.ref, 'refs/tags') steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.7 - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 with: python-version: 3.7 - - name: Build source distribution + - name: Install dependencies run: | - python setup.py sdist + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build --sdist - name: Publish distribution to PyPI uses: pypa/gh-action-pypi-publish@master with: From aa7f40bf1cbde6a020187429049d2a8a4a710fd0 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 09:25:29 +0200 Subject: [PATCH 10/61] remove setup.py from readme --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 8d9e354..db563fc 100644 --- a/README.rst +++ b/README.rst @@ -104,7 +104,7 @@ We recommend that you reformat using `isort `_ a .. code:: bash - $ ssort src/ tests/; isort src/ tests/ setup.py; black src/ tests/ setup.py + $ ssort src/ tests/; isort src/ tests/; black src/ tests/ You can also setup ssort to run automatically before commit by setting up `pre-commit `_, and registering ssort in your `.pre-commit-config.yaml`. 
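
The release workflow updated in PATCH 09 above replaces `python setup.py sdist` with the `build` frontend. A rough local equivalent is sketched below, assuming `build` and `twine` are installed from PyPI; `twine` is not part of the workflow itself (publishing goes through pypa/gh-action-pypi-publish), but it is convenient for checking the archive before tagging a release.

.. code:: bash

    # Build the sdist the same way the updated release job does, then run a
    # local metadata check on the resulting archive.
    python -m pip install --upgrade pip build twine
    python -m build --sdist
    twine check dist/*
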
From b5456a87fe43001211e03a474dda1f7449cedc0d Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 10:31:30 +0200 Subject: [PATCH 11/61] introduce Config class, that parses pyproject.toml and iterates over all valid python files --- src/ssort/_config.py | 77 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 src/ssort/_config.py diff --git a/src/ssort/_config.py b/src/ssort/_config.py new file mode 100644 index 0000000..7c1f211 --- /dev/null +++ b/src/ssort/_config.py @@ -0,0 +1,77 @@ +import tomllib +from dataclasses import dataclass, field +from pathlib import Path + +DEFAULT_SKIP = frozenset( + { + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + } +) + + +def iter_valid_python_files_recursive(folder, *, is_invalid): + for child in folder.iterdir(): + if is_invalid(child): + continue + + if child.is_file() and child.suffix == ".py": + yield child + + if child.is_dir(): + yield from iter_valid_python_files_recursive( + child, is_invalid=is_invalid + ) + + +@dataclass(frozen=True, kw_only=True) +class Config: + skip: frozenset | list = DEFAULT_SKIP + extend_skip: list = field(default_factory=list) + root: Path + + def files(self): + invalid_names = set(self.skip) | set(self.extend_skip) + yield from iter_valid_python_files_recursive( + self.root, + is_invalid=lambda x: x.name in invalid_names, + ) + + +def parse_pyproject_toml(path): + with open(path, "rb") as fh: + pyproject_toml = tomllib.load(fh) + + config = pyproject_toml.get("tool", {}).get("ssort", {}) + config = {key.replace("-", "_"): val for key, val in config.items()} + + return config + + +def get_config_from_root(root): + path_pyproject_toml = root / "pyproject.toml" + + if path_pyproject_toml.exists(): + config_dict = parse_pyproject_toml(path_pyproject_toml) + else: + config_dict = {} + + return Config(**config_dict, root=root) From f6439b9af866081267e92c15e9d3e111d7beccd4 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 10:32:45 +0200 Subject: [PATCH 12/61] remove redundant find_python_files code, add find_project_root --- src/ssort/_files.py | 79 ++++++++------------------------------------- 1 file changed, 13 insertions(+), 66 deletions(-) diff --git a/src/ssort/_files.py b/src/ssort/_files.py index aca21ec..335ee6f 100644 --- a/src/ssort/_files.py +++ b/src/ssort/_files.py @@ -1,72 +1,19 @@ -from __future__ import annotations +from pathlib import Path -import os -import pathlib -from typing import Iterable -import pathspec - -from ssort._utils import memoize - -_EMPTY_PATH_SPEC = pathspec.PathSpec([]) - - -@memoize -def _is_project_root(path: pathlib.Path) -> bool: - if path == path.root or path == path.parent: - return True - - if (path / ".git").is_dir(): - return True - - return False - - -@memoize -def _get_ignore_patterns(path: pathlib.Path) -> pathspec.PathSpec: - git_ignore = path / ".gitignore" - if git_ignore.is_file(): - with git_ignore.open() as f: - return pathspec.PathSpec.from_lines("gitwildmatch", f) - - return _EMPTY_PATH_SPEC - - -def is_ignored(path: str | os.PathLike) -> bool: - # Can't use pathlib.Path.resolve() here because we want to maintain - # symbolic links. 
- path = pathlib.Path(os.path.abspath(path)) - - for part in (path, *path.parents): - patterns = _get_ignore_patterns(part) - if patterns.match_file(path.relative_to(part)): - return True - - if _is_project_root(part): - return False - - return False - - -def find_python_files( - patterns: Iterable[str | os.PathLike[str]], -) -> Iterable[pathlib.Path]: +def find_project_root(patterns): if not patterns: patterns = ["."] - paths_set = set() - for pattern in patterns: - path = pathlib.Path(pattern) - if not path.is_dir(): - subpaths = [path] - else: - subpaths = [ - subpath - for subpath in path.glob("**/*.py") - if not is_ignored(subpath) and subpath.is_file() - ] + paths = [Path(p).resolve() for p in patterns] + parents = [([p] if p.is_dir() else []) + list(p.parents) for p in paths] + + *_, (common_base, *_) = zip(*(reversed(p) for p in parents)) + + for directory in (common_base, *common_base.parents): + if (directory / ".git").exists() or ( + directory / "pyproject.toml" + ).is_file(): + return directory - for subpath in sorted(subpaths): - if subpath not in paths_set: - paths_set.add(subpath) - yield subpath + return directory From fa2b99fadcc7ac4214b95b67720ec15bf6e5bdb6 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 10:33:30 +0200 Subject: [PATCH 13/61] use new config to iterate over python files --- src/ssort/_main.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index a7efa1d..f376e83 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -4,7 +4,7 @@ import sys from ssort._exceptions import UnknownEncodingError -from ssort._files import find_python_files +from ssort._files import find_project_root from ssort._ssort import ssort from ssort._utils import ( detect_encoding, @@ -13,6 +13,8 @@ normalize_newlines, ) +from ._config import get_config_from_root + def main(): parser = argparse.ArgumentParser( @@ -42,7 +44,10 @@ def main(): unsortable = 0 unchanged = 0 - for path in find_python_files(args.files): + root = find_project_root(args.files) + config = get_config_from_root(root) + + for path in config.files(): errors = False try: From 193af08a7a204bb5f35e9c2ff90acf2f066cb33b Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 10:34:18 +0200 Subject: [PATCH 14/61] remove unnecessary error handling, both are handled by config.files() --- src/ssort/_main.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index f376e83..c569456 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -52,14 +52,6 @@ def main(): try: original_bytes = path.read_bytes() - except FileNotFoundError: - sys.stderr.write(f"ERROR: {escape_path(path)} does not exist\n") - unsortable += 1 - continue - except IsADirectoryError: - sys.stderr.write(f"ERROR: {escape_path(path)} is a directory\n") - unsortable += 1 - continue except PermissionError: sys.stderr.write(f"ERROR: {escape_path(path)} is not readable\n") unsortable += 1 From 5080f51a6be1d33bfa7fb9873dae2093f3b661ac Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 10:35:26 +0200 Subject: [PATCH 15/61] add config for ssort --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5ab348e..3c62970 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,3 +75,6 @@ attr = "ssort.__version__" [tool.setuptools.packages.find] where = ["src"] + +[tool.ssort] +extend_skip = ['test_data'] From 3955ad8b4df485eda1cc70ea7c5c6a65228a995e 
Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Tue, 25 Apr 2023 12:03:14 +0200 Subject: [PATCH 16/61] update tox.ini for PEP517 compatibility --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 78ed059..5c52c21 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,6 @@ [tox] envlist = py38,py39,py310,black,isort,ssort,pyflakes,pylint,mypy +isolated_build = true [testenv] deps = From 3d55df17a405640d08866c51bab0103b96899ef1 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 11:44:33 +0200 Subject: [PATCH 17/61] update tests for _files.py --- tests/test_files.py | 247 +++++++++----------------------------------- 1 file changed, 46 insertions(+), 201 deletions(-) diff --git a/tests/test_files.py b/tests/test_files.py index 42a0c55..de75039 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1,223 +1,68 @@ -from __future__ import annotations - -import pathlib - import pytest -from ssort._files import is_ignored - - -def test_ignore_git( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) - - (tmp_path / ".git").mkdir() - (tmp_path / ".gitignore").write_text("ignored") - - assert not is_ignored("src") - assert not is_ignored("src/main.py") - - assert is_ignored("ignored") - assert is_ignored("ignored/main.py") - - assert is_ignored("src/ignored") - assert is_ignored("src/ignored/main.py") - - assert not is_ignored("../ignored") - assert not is_ignored("../ignored/main.py") - - assert is_ignored(f"../{tmp_path.name}/ignored") - assert is_ignored(f"../{tmp_path.name}/ignored/main.py") - - -def test_ignore_git_with_no_repo( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) - - (tmp_path / ".gitignore").write_text("ignored") - - assert not is_ignored("src") - assert not is_ignored("src/main.py") - - assert is_ignored("ignored") - assert is_ignored("ignored/main.py") - - assert is_ignored("src/ignored") - assert is_ignored("src/ignored/main.py") - - assert not is_ignored("../ignored") - assert not is_ignored("../ignored/main.py") - - assert is_ignored(f"../{tmp_path.name}/ignored") - assert is_ignored(f"../{tmp_path.name}/ignored/main.py") - - -def test_ignore_git_in_subdirectory( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) - - (tmp_path / ".git").mkdir() - (tmp_path / ".gitignore").write_text("parent") - - (tmp_path / "sub").mkdir() - (tmp_path / "sub" / ".gitignore").write_text("child") - - assert not is_ignored("src") - assert not is_ignored("src/main.py") - assert not is_ignored("sub/src") - assert not is_ignored("sub/src/main.py") - - assert is_ignored("parent") - assert is_ignored("parent/main.py") - assert is_ignored("sub/parent") - assert is_ignored("sub/parent/main.py") - - assert is_ignored("src/parent") - assert is_ignored("src/parent/main.py") - assert is_ignored("sub/src/parent") - assert is_ignored("sub/src/parent/main.py") - - assert not is_ignored("../parent") - assert not is_ignored("../parent/main.py") - assert not is_ignored("../sub/parent") - assert not is_ignored("../sub/parent/main.py") - - assert is_ignored(f"../{tmp_path.name}/parent") - assert is_ignored(f"../{tmp_path.name}/parent/main.py") - assert is_ignored(f"../{tmp_path.name}/sub/parent") - assert is_ignored(f"../{tmp_path.name}/sub/parent/main.py") - - assert not is_ignored("child") - assert not is_ignored("child/main.py") - assert is_ignored("sub/child") - assert is_ignored("sub/child/main.py") - - assert 
not is_ignored("src/child") - assert not is_ignored("src/child/main.py") - assert is_ignored("sub/src/child") - assert is_ignored("sub/src/child/main.py") - - assert not is_ignored("sub/../child") - assert not is_ignored("sub/../child/main.py") - - assert is_ignored(f"../{tmp_path.name}/sub/child") - assert is_ignored(f"../{tmp_path.name}/sub/child/main.py") - - -def test_ignore_git_in_working_subdirectory( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - (tmp_path / ".git").mkdir() - (tmp_path / ".gitignore").write_text("ignored") - - (tmp_path / "sub").mkdir() - monkeypatch.chdir(tmp_path / "sub") - - assert not is_ignored("src") - assert not is_ignored("src/main.py") - - assert is_ignored("ignored") - assert is_ignored("ignored/main.py") - - assert is_ignored("src/ignored") - assert is_ignored("src/ignored/main.py") - - assert is_ignored("../ignored") - assert is_ignored("../ignored/main.py") - - assert is_ignored("../sub/ignored") - assert is_ignored("../sub/ignored/main.py") - - assert not is_ignored("../../ignored") - assert not is_ignored("../../ignored/main.py") - - -def test_ignore_git_in_working_parent_directory( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) - - (tmp_path / "sub").mkdir() - (tmp_path / "sub" / ".git").mkdir() - (tmp_path / "sub" / ".gitignore").write_text("ignored") - - assert not is_ignored("ignored") - assert not is_ignored("ignored/main.py") - - assert is_ignored("sub/ignored") - assert is_ignored("sub/ignored/main.py") - - assert is_ignored("sub/src/ignored") - assert is_ignored("sub/src/ignored/main.py") - - assert not is_ignored("sub/../ignored") - assert not is_ignored("sub/../ignored/main.py") - +from ssort._files import find_project_root -def test_ignore_git_subdirectory_pattern( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) +subdirs = [ + (".",), + (".", "dir"), + ("dir", "dirA/B"), + ("dirA/B", "dir"), + ("a/a/a", "b/b/b"), + ("dir/a", "dir/b"), +] - (tmp_path / ".git").mkdir() - (tmp_path / ".gitignore").write_text("sub/ignored") - (tmp_path / "sub").mkdir() +@pytest.fixture() +def git(tmp_path): + root = tmp_path / "root" + (root / ".git").mkdir(parents=True) + return root - assert not is_ignored("sub") - assert not is_ignored("sub/main.py") - assert is_ignored("sub/ignored") - assert is_ignored("sub/ignored/main.py") +@pytest.mark.parametrize("subdir", subdirs) +def test_find_project_root_git(subdir, git): + print(subdir) + patterns = [git / sub for sub in subdir] + for p in patterns: + p.mkdir(parents=True, exist_ok=True) -def test_ignore_git_symlink_recursive( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) + assert git == find_project_root(patterns) - (tmp_path / ".git").mkdir() - (tmp_path / ".gitignore").write_text("ignored") - (tmp_path / "dir").mkdir() - (tmp_path / "dir" / "link").symlink_to(tmp_path / "dir") +@pytest.fixture() +def pyproject(tmp_path): + root = tmp_path / "root" + root.mkdir() + (root / "pyproject.toml").touch() + return root - assert not is_ignored("dir") - assert not is_ignored("dir/link") - assert not is_ignored("dir/link/link") - assert is_ignored("dir/ignored") - assert is_ignored("dir/link/ignored") - assert is_ignored("dir/link/link/ignored") +@pytest.mark.parametrize("subdir", subdirs) +def test_find_project_root_pyproject(subdir, pyproject): + patterns = [pyproject / sub for sub in subdir] + for p in patterns: + p.mkdir(parents=True, 
exist_ok=True) -def test_ignore_git_symlink_outside_repo( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) + assert pyproject == find_project_root(patterns) - (tmp_path / "repo" / ".git").mkdir(parents=True) - (tmp_path / "repo" / ".gitignore").write_text("link") - (tmp_path / "link").mkdir() - (tmp_path / "repo" / "link").symlink_to(tmp_path / "link") +@pytest.fixture() +def neither(tmp_path): + root = tmp_path / "root" + root.mkdir() + return root - assert not is_ignored("link") - assert not is_ignored("link/main.py") - assert is_ignored("repo/link") - assert is_ignored("repo/link/main.py") +@pytest.mark.parametrize("subdir", subdirs) +def test_find_project_root_neither(subdir, neither): + patterns = [neither / sub for sub in subdir] -def test_ignore_symlink_circular( - tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.chdir(tmp_path) + if all(s.startswith("dir/") for s in subdir): + neither = neither / "dir" - (tmp_path / "link1").symlink_to(tmp_path / "link2") - (tmp_path / "link2").symlink_to(tmp_path / "link1") + for p in patterns: + p.mkdir(parents=True, exist_ok=True) - assert not is_ignored("link1") - assert not is_ignored("link2") + assert neither == find_project_root(patterns) From 3b91f35547ebab4a91dd07314be99340f4b83910 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 11:45:39 +0200 Subject: [PATCH 18/61] fix bug find_project_root --- src/ssort/_files.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/ssort/_files.py b/src/ssort/_files.py index 335ee6f..9968582 100644 --- a/src/ssort/_files.py +++ b/src/ssort/_files.py @@ -6,9 +6,15 @@ def find_project_root(patterns): patterns = ["."] paths = [Path(p).resolve() for p in patterns] - parents = [([p] if p.is_dir() else []) + list(p.parents) for p in paths] + parents_and_self = [ + list(reversed(p.parents)) + ([p] if p.is_dir() else []) for p in paths + ] - *_, (common_base, *_) = zip(*(reversed(p) for p in parents)) + *_, (common_base, *_) = ( + common_parent + for same_lvl_parent in zip(*parents_and_self) + if len(common_parent := set(same_lvl_parent)) == 1 + ) for directory in (common_base, *common_base.parents): if (directory / ".git").exists() or ( @@ -16,4 +22,4 @@ def find_project_root(patterns): ).is_file(): return directory - return directory + return common_base From 6a659df91697a805db13730f30cb1ebd0407e6b8 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 11:46:34 +0200 Subject: [PATCH 19/61] use tomli for python < 3.11, tomlib for python >= 3.11 --- src/ssort/_config.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/ssort/_config.py b/src/ssort/_config.py index 7c1f211..3166d4c 100644 --- a/src/ssort/_config.py +++ b/src/ssort/_config.py @@ -1,7 +1,15 @@ -import tomllib +from __future__ import annotations + +import sys from dataclasses import dataclass, field from pathlib import Path +if sys.version_info >= (3, 11): + from tomllib import load +else: + from tomli import load + + DEFAULT_SKIP = frozenset( { ".bzr", @@ -42,11 +50,11 @@ def iter_valid_python_files_recursive(folder, *, is_invalid): ) -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class Config: + root: Path skip: frozenset | list = DEFAULT_SKIP extend_skip: list = field(default_factory=list) - root: Path def files(self): invalid_names = set(self.skip) | set(self.extend_skip) @@ -58,7 +66,7 @@ def files(self): def parse_pyproject_toml(path): with 
open(path, "rb") as fh: - pyproject_toml = tomllib.load(fh) + pyproject_toml = load(fh) config = pyproject_toml.get("tool", {}).get("ssort", {}) config = {key.replace("-", "_"): val for key, val in config.items()} From 797424bb763ba63f7d200c608c7dd4ba89ffff78 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 11:50:21 +0200 Subject: [PATCH 20/61] add tomli dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3c62970..b62b080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Topic :: Software Development :: Quality Assurance" ] dependencies = [ - "pathspec >=0.9.0" + "tomli; python_version<'3.11'" ] description = "The python statement sorter" dynamic = [ From cafa5bb3f83a68b753256562ec403eec7ef41d6a Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 11:52:03 +0200 Subject: [PATCH 21/61] remove pathspec mypy block since pathspec is not longer needed --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b62b080..aa1d5a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,10 +60,6 @@ profile = "black" [tool.mypy] exclude = "test_data/samples/*" -[[tool.mypy.overrides]] -ignore_missing_imports = true -module = "pathspec" - [tool.setuptools] include-package-data = false license-files = [ From 674ef1242c209bc2639942eab75125910efcc85b Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 12:02:58 +0200 Subject: [PATCH 22/61] remove test to check for not existent file, this cannot happen anymore --- tests/test_executable.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/tests/test_executable.py b/tests/test_executable.py index 0260d26..ec29d82 100644 --- a/tests/test_executable.py +++ b/tests/test_executable.py @@ -354,24 +354,10 @@ def test_ssort_empty_dir(ssort, tmp_path): assert (actual_msgs, actual_status) == (expected_msgs, expected_status) -def test_ssort_non_existent_file(ssort, tmp_path): - path = tmp_path / "file.py" - - expected_msgs = [ - f"ERROR: {escape_path(path)} does not exist\n", - "1 file was not sortable\n", - ] - expected_status = 1 - - actual_msgs, actual_status = ssort(path) - - assert (actual_msgs, actual_status) == (expected_msgs, expected_status) - - def test_ssort_no_py_extension(ssort, tmp_path): path = tmp_path / "file" path.write_bytes(_good) - expected_msgs = ["1 file was left unchanged\n"] + expected_msgs = ["No files are present to be sorted. 
Nothing to do.\n"] expected_status = 0 actual_msgs, actual_status = ssort(path) assert (actual_msgs, actual_status) == (expected_msgs, expected_status) From 91d8317b9bbbd35727e1f43828cc738ac37decd4 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 12:03:49 +0200 Subject: [PATCH 23/61] mock find_project_root so that pytest does not search for pyproject.toml beyond the specified test-folder --- tests/test_executable.py | 5 +++++ tox.ini | 1 + 2 files changed, 6 insertions(+) diff --git a/tests/test_executable.py b/tests/test_executable.py index ec29d82..64e6ae0 100644 --- a/tests/test_executable.py +++ b/tests/test_executable.py @@ -65,6 +65,11 @@ def _write_fixtures(dirpath, texts): return paths +@pytest.fixture(autouse=True) +def root(mocker, tmp_path): + mocker.patch("ssort._main.find_project_root", return_value=tmp_path) + + @pytest.fixture(params=["entrypoint", "module"]) def check(request): def _check(dirpath): diff --git a/tox.ini b/tox.ini index 5c52c21..50f654e 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ isolated_build = true [testenv] deps = pytest + pytest-mock commands = pytest -vv tests From 192d9ec09c47ab9bb31759ddd17e5c82cb9ec8db Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Wed, 26 Apr 2023 12:39:35 +0200 Subject: [PATCH 24/61] ssort only sorts the inteded files, not all files under root --- src/ssort/_config.py | 23 +++++++++++++++-------- src/ssort/_main.py | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/ssort/_config.py b/src/ssort/_config.py index 3166d4c..1b76d42 100644 --- a/src/ssort/_config.py +++ b/src/ssort/_config.py @@ -52,16 +52,23 @@ def iter_valid_python_files_recursive(folder, *, is_invalid): @dataclass(frozen=True) class Config: - root: Path skip: frozenset | list = DEFAULT_SKIP extend_skip: list = field(default_factory=list) - def files(self): - invalid_names = set(self.skip) | set(self.extend_skip) - yield from iter_valid_python_files_recursive( - self.root, - is_invalid=lambda x: x.name in invalid_names, - ) + def iterate_files_matching_pattern(self, pattern): + for pat in pattern: + path = Path(pat).resolve() + + if path.is_file() and path.suffix == ".py": + yield path + continue + + if path.is_dir(): + invalid_names = set(self.skip) | set(self.extend_skip) + yield from iter_valid_python_files_recursive( + path, + is_invalid=lambda x: x.name in invalid_names, + ) def parse_pyproject_toml(path): @@ -82,4 +89,4 @@ def get_config_from_root(root): else: config_dict = {} - return Config(**config_dict, root=root) + return Config(**config_dict) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index c569456..9e1b4bd 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -47,7 +47,7 @@ def main(): root = find_project_root(args.files) config = get_config_from_root(root) - for path in config.files(): + for path in config.iterate_files_matching_pattern(args.files): errors = False try: From 27fee83d3e18dc063ed1464cd3a33e9d35b8d9cf Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 11:11:44 +0200 Subject: [PATCH 25/61] Revert "remove unnecessary error handling, both are handled by config.files()" This reverts commit 193af08a7a204bb5f35e9c2ff90acf2f066cb33b. 
--- src/ssort/_main.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index 9e1b4bd..edb056a 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -52,6 +52,14 @@ def main(): try: original_bytes = path.read_bytes() + except FileNotFoundError: + sys.stderr.write(f"ERROR: {escape_path(path)} does not exist\n") + unsortable += 1 + continue + except IsADirectoryError: + sys.stderr.write(f"ERROR: {escape_path(path)} is a directory\n") + unsortable += 1 + continue except PermissionError: sys.stderr.write(f"ERROR: {escape_path(path)} is not readable\n") unsortable += 1 From 7a44e8a17dc1be9deb38d25f47d0cb07d86d342a Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 11:12:56 +0200 Subject: [PATCH 26/61] Revert "remove test to check for not existent file, this cannot happen anymore" This reverts commit 674ef1242c209bc2639942eab75125910efcc85b. --- tests/test_executable.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_executable.py b/tests/test_executable.py index 64e6ae0..6229c9e 100644 --- a/tests/test_executable.py +++ b/tests/test_executable.py @@ -359,10 +359,24 @@ def test_ssort_empty_dir(ssort, tmp_path): assert (actual_msgs, actual_status) == (expected_msgs, expected_status) +def test_ssort_non_existent_file(ssort, tmp_path): + path = tmp_path / "file.py" + + expected_msgs = [ + f"ERROR: {escape_path(path)} does not exist\n", + "1 file was not sortable\n", + ] + expected_status = 1 + + actual_msgs, actual_status = ssort(path) + + assert (actual_msgs, actual_status) == (expected_msgs, expected_status) + + def test_ssort_no_py_extension(ssort, tmp_path): path = tmp_path / "file" path.write_bytes(_good) - expected_msgs = ["No files are present to be sorted. Nothing to do.\n"] + expected_msgs = ["1 file was left unchanged\n"] expected_status = 0 actual_msgs, actual_status = ssort(path) assert (actual_msgs, actual_status) == (expected_msgs, expected_status) From 64ae5b601e9cfa741cc618ac49a5752fe6c47109 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 11:19:00 +0200 Subject: [PATCH 27/61] non existent and no py files are now handled by the config class and dont create an error msg --- tests/test_executable.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_executable.py b/tests/test_executable.py index 6229c9e..83eac71 100644 --- a/tests/test_executable.py +++ b/tests/test_executable.py @@ -362,11 +362,8 @@ def test_ssort_empty_dir(ssort, tmp_path): def test_ssort_non_existent_file(ssort, tmp_path): path = tmp_path / "file.py" - expected_msgs = [ - f"ERROR: {escape_path(path)} does not exist\n", - "1 file was not sortable\n", - ] - expected_status = 1 + expected_msgs = ["No files are present to be sorted. Nothing to do.\n"] + expected_status = 0 actual_msgs, actual_status = ssort(path) @@ -376,7 +373,7 @@ def test_ssort_non_existent_file(ssort, tmp_path): def test_ssort_no_py_extension(ssort, tmp_path): path = tmp_path / "file" path.write_bytes(_good) - expected_msgs = ["1 file was left unchanged\n"] + expected_msgs = ["No files are present to be sorted. 
Nothing to do.\n"] expected_status = 0 actual_msgs, actual_status = ssort(path) assert (actual_msgs, actual_status) == (expected_msgs, expected_status) From 7d57482fd3caf6a629c9255e727227fa56c70c17 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 11:20:07 +0200 Subject: [PATCH 28/61] FileNotFoundError case handled by config --- src/ssort/_main.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index edb056a..8f59381 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -52,10 +52,6 @@ def main(): try: original_bytes = path.read_bytes() - except FileNotFoundError: - sys.stderr.write(f"ERROR: {escape_path(path)} does not exist\n") - unsortable += 1 - continue except IsADirectoryError: sys.stderr.write(f"ERROR: {escape_path(path)} is a directory\n") unsortable += 1 From 7e412748f04a694d278e0d9058b0436e225b33c9 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 11:20:37 +0200 Subject: [PATCH 29/61] path is directory is handled by dir iterator of config --- src/ssort/_main.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index 8f59381..9e1b4bd 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -52,10 +52,6 @@ def main(): try: original_bytes = path.read_bytes() - except IsADirectoryError: - sys.stderr.write(f"ERROR: {escape_path(path)} is a directory\n") - unsortable += 1 - continue except PermissionError: sys.stderr.write(f"ERROR: {escape_path(path)} is not readable\n") unsortable += 1 From 31381d9725817aa87060e1cd767fa0db8f35d968 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 13:40:19 +0200 Subject: [PATCH 30/61] add tests for config --- src/ssort/_config.py | 21 ++--- src/ssort/_main.py | 2 +- tests/test_config.py | 194 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 13 deletions(-) create mode 100644 tests/test_config.py diff --git a/src/ssort/_config.py b/src/ssort/_config.py index 1b76d42..11948d2 100644 --- a/src/ssort/_config.py +++ b/src/ssort/_config.py @@ -41,10 +41,10 @@ def iter_valid_python_files_recursive(folder, *, is_invalid): if is_invalid(child): continue - if child.is_file() and child.suffix == ".py": + elif child.is_file() and child.suffix == ".py": yield child - if child.is_dir(): + elif child.is_dir(): yield from iter_valid_python_files_recursive( child, is_invalid=is_invalid ) @@ -55,19 +55,19 @@ class Config: skip: frozenset | list = DEFAULT_SKIP extend_skip: list = field(default_factory=list) - def iterate_files_matching_pattern(self, pattern): + def is_invalid(self, x): + return x.name in set(self.skip) | set(self.extend_skip) + + def iterate_files_matching_patterns(self, pattern): for pat in pattern: path = Path(pat).resolve() if path.is_file() and path.suffix == ".py": yield path - continue - if path.is_dir(): - invalid_names = set(self.skip) | set(self.extend_skip) + elif path.is_dir(): yield from iter_valid_python_files_recursive( - path, - is_invalid=lambda x: x.name in invalid_names, + path, is_invalid=self.is_invalid ) @@ -75,10 +75,7 @@ def parse_pyproject_toml(path): with open(path, "rb") as fh: pyproject_toml = load(fh) - config = pyproject_toml.get("tool", {}).get("ssort", {}) - config = {key.replace("-", "_"): val for key, val in config.items()} - - return config + return pyproject_toml.get("tool", {}).get("ssort", {}) def get_config_from_root(root): diff --git a/src/ssort/_main.py b/src/ssort/_main.py index 9e1b4bd..e6b1f54 100644 --- a/src/ssort/_main.py 
+++ b/src/ssort/_main.py @@ -47,7 +47,7 @@ def main(): root = find_project_root(args.files) config = get_config_from_root(root) - for path in config.iterate_files_matching_pattern(args.files): + for path in config.iterate_files_matching_patterns(args.files): errors = False try: diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..bffc805 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,194 @@ +from pathlib import Path + +import pytest + +import ssort._config as config + + +def test_default_skip_defined(): + assert hasattr(config, "DEFAULT_SKIP") + + +class TestIterValidPythonFiles: + @pytest.fixture() + def names(self): + return ["dog.py", "cat.py", "apple.py", "dir/meow.py"] + + @pytest.fixture() + def is_invalid(self, names): + def fun(name): + return name in names + + return fun + + @pytest.fixture() + def folder(self, tmp_path, names): + for name in names: + path = tmp_path / name + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + return tmp_path + + def test_iter_valid_python_files_recursive( + self, folder, is_invalid, names + ): + ret = list( + config.iter_valid_python_files_recursive( + folder, is_invalid=is_invalid + ) + ) + + assert len(ret) == len(names) + + for name in ret: + assert str(name.relative_to(folder)) in names + + def test_iter_valid_python_files_recursive_empty(self, tmp_path): + with pytest.raises(StopIteration): + next( + config.iter_valid_python_files_recursive( + tmp_path, is_invalid=lambda x: True + ) + ) + + +class TestConfig: + def test___init___default_values(self): + cfg = config.Config() + + assert hasattr(cfg, "skip") + assert hasattr(cfg, "extend_skip") + + assert isinstance(cfg.skip, frozenset) + assert len(cfg.skip) > 0 + + assert cfg.extend_skip == [] + + def test___init___overwrite_default(self): + cfg = config.Config(skip="no", extend_skip="nono") + + assert cfg.skip == "no" + assert cfg.extend_skip == "nono" + + def test_is_invalid(self): + skip = ["hello", "banana"] + extend_skip = ["world"] + + cfg = config.Config(skip=skip, extend_skip=extend_skip) + + skip.extend(extend_skip) + + for s in skip: + assert cfg.is_invalid(Path(s)) + assert cfg.is_invalid(Path("directory") / s) + assert not cfg.is_invalid(Path(s) / "directory") + + for s in ["apple", "bananas", "bananas/worlds", "Hello"]: + assert not cfg.is_invalid(Path(s)) + + @pytest.fixture() + def mock_iter(self, mocker): + return mocker.patch( + "ssort._config.iter_valid_python_files_recursive", + return_value=[None], + ) + + def test_iterate_files_matching_patterns_existing_python_file( + self, tmp_path, mock_iter + ): + path = tmp_path / "file.py" + path.touch() + + cfg = config.Config() + ret = cfg.iterate_files_matching_patterns([path]) + + assert list(ret) == [path] + mock_iter.assert_not_called() + + def test_iterate_files_matching_patterns_missing_python_file( + self, tmp_path, mock_iter + ): + path = tmp_path / "file.py" + + cfg = config.Config() + ret = cfg.iterate_files_matching_patterns([path]) + + assert list(ret) == [] + mock_iter.assert_not_called() + + def test_iterate_files_matching_patterns_dir(self, tmp_path, mock_iter): + path0 = tmp_path / "dir0" + path1 = tmp_path / "dir1" + path0.mkdir() + path1.mkdir() + + cfg = config.Config() + ret = cfg.iterate_files_matching_patterns([path0, path1]) + + assert list(ret) == [None, None] + mock_iter.assert_any_call(path0, is_invalid=cfg.is_invalid) + mock_iter.assert_any_call(path1, is_invalid=cfg.is_invalid) + assert mock_iter.call_count == 2 + + +@pytest.fixture() +def 
toml(tmp_path): + toml = """ + [tool.ssort] + banana = 'banana' + extend_skip = ['extend_skip'] + skip = ['skip'] + name-with-dash = '---' + + [tool.ssort.class] + sort_order = [] + name-with-dash = '---' + """ + + path = tmp_path / "test.tml" + with open(path, "w") as fh: + fh.write(toml) + + return path + + +def test_parse_pyproject_toml(toml, mocker): + ret = config.parse_pyproject_toml(toml) + + assert ret["banana"] == "banana" + assert ret["name-with-dash"] == "---" + assert ret["skip"] == ["skip"] + + assert ret["class"]["name-with-dash"] == "---" + assert ret["class"]["sort_order"] == [] + assert ret["extend_skip"] == ["extend_skip"] + + +def test_get_config_from_root_exists(mocker, tmp_path): + return_value = {"skyscraper": "tall"} + mock_parse = mocker.patch( + "ssort._config.parse_pyproject_toml", return_value=return_value + ) + mock_config = mocker.patch("ssort._config.Config") + + path = tmp_path / "pyproject.toml" + path.touch() + assert path.exists() + + config.get_config_from_root(path.parent) + + mock_parse.assert_called_once_with(path) + mock_config.assert_called_once_with(**return_value) + + +def test_get_config_from_root_not_exists(mocker, tmp_path): + mock_parse = mocker.patch("ssort._config.parse_pyproject_toml") + mock_config = mocker.patch("ssort._config.Config") + + path = tmp_path / "pyproject.toml" + assert not path.exists() + + config.get_config_from_root(path.parent) + + mock_parse.assert_not_called() + mock_config.assert_called_once_with(**{}) From de951ec43cfd7e50f1521d84dd996c70b2a68b56 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 13:43:12 +0200 Subject: [PATCH 31/61] update tox.ini to use new config --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 50f654e..710efd8 100644 --- a/tox.ini +++ b/tox.ini @@ -25,7 +25,7 @@ commands = [testenv:ssort] commands = - ssort --check --diff src/ssort tests + ssort --check --diff . 
[testenv:pyflakes] deps = From 9415c06a5f8b00946a051e15c23681c1753d0e4f Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 13:56:41 +0200 Subject: [PATCH 32/61] define pytest test path in pyproject.toml --- .github/workflows/ci.yaml | 4 ++-- pyproject.toml | 5 +++++ tox.ini | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8206e37..5635be6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,7 +27,7 @@ jobs: pip install -e .[test] - name: Run tests run: | - pytest -vv tests/ + pytest -vv coverage: name: "Coverage" @@ -46,7 +46,7 @@ jobs: pip install -e .[test] - name: Run tests run: | - pytest --cov=ssort -v tests/ + pytest --cov=ssort -v - name: Upload coverage report to coveralls run: | coveralls --service=github diff --git a/pyproject.toml b/pyproject.toml index 3c62970..0c45183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,11 @@ exclude = "test_data/samples/*" ignore_missing_imports = true module = "pathspec" +[tool.pytest.ini_options] +testpaths = [ + "tests" +] + [tool.setuptools] include-package-data = false license-files = [ diff --git a/tox.ini b/tox.ini index 1da2664..1218d12 100644 --- a/tox.ini +++ b/tox.ini @@ -33,7 +33,7 @@ deps = pyflakes skip_install = True commands = - pyflakes src/ssort tests + pyflakes src/ tests/ [testenv:pylint] deps = From f8e49deae15e85a789d644e8da17d7e0a8cf93cb Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 13:59:53 +0200 Subject: [PATCH 33/61] sync tox.ini and ci.yaml --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5635be6..39a9110 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -102,7 +102,7 @@ jobs: pip install -e . - name: Run ssort run: | - ssort --check --diff src/ tests/ + ssort --check --diff . 
pyflakes: name: "PyFlakes" From d65d30415fbfee2df0203065a7f3067512985ddc Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:51:12 +0200 Subject: [PATCH 34/61] export only get_config_from_root using __all__ --- src/ssort/_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/ssort/_config.py b/src/ssort/_config.py index 11948d2..b82ecdf 100644 --- a/src/ssort/_config.py +++ b/src/ssort/_config.py @@ -10,6 +10,9 @@ from tomli import load +__all__ = ["get_config_from_root"] + + DEFAULT_SKIP = frozenset( { ".bzr", From e9689b5904b27f7feca430f59a81da547ff9595b Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:54:36 +0200 Subject: [PATCH 35/61] fix import statement to mirror the others --- src/ssort/_main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index e6b1f54..ee6cfd5 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -3,6 +3,7 @@ import re import sys +from ssort._config import get_config_from_root from ssort._exceptions import UnknownEncodingError from ssort._files import find_project_root from ssort._ssort import ssort @@ -13,8 +14,6 @@ normalize_newlines, ) -from ._config import get_config_from_root - def main(): parser = argparse.ArgumentParser( From 08cb4d7137bcfe2212ca7b00e0f47d1c6d3b4fb7 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 08:35:41 +0200 Subject: [PATCH 36/61] add current working dir to find_project_root to find the expected pyproject.toml --- src/ssort/_files.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/ssort/_files.py b/src/ssort/_files.py index 9968582..330f7c0 100644 --- a/src/ssort/_files.py +++ b/src/ssort/_files.py @@ -2,10 +2,12 @@ def find_project_root(patterns): - if not patterns: - patterns = ["."] + all_patterns = [Path(".").resolve()] - paths = [Path(p).resolve() for p in patterns] + if patterns: + all_patterns.extend(patterns) + + paths = [Path(p).resolve() for p in all_patterns] parents_and_self = [ list(reversed(p.parents)) + ([p] if p.is_dir() else []) for p in paths ] From 6bbf597f9b025335eb2b011f6fc550e227b03916 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 08:36:26 +0200 Subject: [PATCH 37/61] add skip_glob configuration key to filter files with glob pattern --- src/ssort/_config.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/ssort/_config.py b/src/ssort/_config.py index b82ecdf..e375a47 100644 --- a/src/ssort/_config.py +++ b/src/ssort/_config.py @@ -23,11 +23,13 @@ ".mypy_cache", ".nox", ".pants.d", + ".pytest_cache", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", + "__pycache__", "__pypackages__", "_build", "buck-out", @@ -56,10 +58,18 @@ def iter_valid_python_files_recursive(folder, *, is_invalid): @dataclass(frozen=True) class Config: skip: frozenset | list = DEFAULT_SKIP + skip_glob: list = field(default_factory=list) extend_skip: list = field(default_factory=list) - def is_invalid(self, x): - return x.name in set(self.skip) | set(self.extend_skip) + def is_invalid(self, path): + if path.name in (set(self.skip) | set(self.extend_skip)): + return True + + for pat in self.skip_glob: + if path.match(pat): + return True + + return False def iterate_files_matching_patterns(self, pattern): for pat in pattern: From ad3ff91a0b9103b39c516588d61da3c1fdb964e9 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 08:56:48 +0200 Subject: [PATCH 38/61] add tests for Config.is_invalid --- 
tests/test_config.py | 68 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index bffc805..7867c36 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,7 +12,31 @@ def test_default_skip_defined(): class TestIterValidPythonFiles: @pytest.fixture() def names(self): - return ["dog.py", "cat.py", "apple.py", "dir/meow.py"] + return [ + "apple.py", + "cat.py", + "cats.py", + "dir/bark.py", + "dir/meow.py", + "dog.py", + "cats/not_a_cat.py", + ] + + @pytest.fixture() + def skip_glob(self): + return ["dir/*", "cat*"] + + @pytest.fixture() + def is_invalid_glob(self, tmp_path): + def fun(path): + return path.relative_to(tmp_path) in [ + Path("cat.py"), + Path("cats.py"), + Path("dir/bark.py"), + Path("dir/meow.py"), + ] + + return fun @pytest.fixture() def is_invalid(self, names): @@ -29,6 +53,48 @@ def folder(self, tmp_path, names): path.touch() return tmp_path + @pytest.fixture() + def files(self, names, folder): + return [folder / name for name in names] + + def test_is_invalid_skip_only(self, names, folder, is_invalid, files): + cfg = config.Config(skip=names) + + valid = folder / "banana.py" + valid.touch() + files.append(valid) + + valid = folder / "dir" / "banana.py" + valid.touch() + files.append(valid) + + for file in files: + assert is_invalid(file.name) == cfg.is_invalid(file) + + def test_is_invalid_extend_skip_only( + self, names, folder, is_invalid, files + ): + cfg = config.Config(skip=[], extend_skip=names) + + valid = folder / "banana.py" + valid.touch() + files.append(valid) + + valid = folder / "dir" / "banana.py" + valid.touch() + files.append(valid) + + for file in files: + assert is_invalid(file.name) == cfg.is_invalid(file) + + def test_is_invalid_skip_glob_only( + self, skip_glob, is_invalid_glob, files + ): + cfg = config.Config(skip=[], skip_glob=skip_glob) + + for file in files: + assert is_invalid_glob(file) == cfg.is_invalid(file) + def test_iter_valid_python_files_recursive( self, folder, is_invalid, names ): From 8c049d1875e95a49766d2e93fbbdcd7975093ed6 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 09:09:07 +0200 Subject: [PATCH 39/61] fix a bug where glob pattern was incorrectly applied if path was not a file --- src/ssort/_config.py | 2 +- tests/test_config.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ssort/_config.py b/src/ssort/_config.py index e375a47..d68827f 100644 --- a/src/ssort/_config.py +++ b/src/ssort/_config.py @@ -66,7 +66,7 @@ def is_invalid(self, path): return True for pat in self.skip_glob: - if path.match(pat): + if path.is_file() and path.match(pat): return True return False diff --git a/tests/test_config.py b/tests/test_config.py index 7867c36..4a0722f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -94,6 +94,7 @@ def test_is_invalid_skip_glob_only( for file in files: assert is_invalid_glob(file) == cfg.is_invalid(file) + assert is_invalid_glob(file.parent) == cfg.is_invalid(file.parent) def test_iter_valid_python_files_recursive( self, folder, is_invalid, names From ee33e99da71ed69ea6506cec9fdcdbabd247d70d Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 09:19:17 +0200 Subject: [PATCH 40/61] create current_working_dir to make testing easier, add more tests --- src/ssort/_files.py | 8 ++- tests/test_files.py | 153 +++++++++++++++++++++++++------------------- 2 files changed, 94 insertions(+), 67 deletions(-) diff --git a/src/ssort/_files.py 
b/src/ssort/_files.py index 330f7c0..13b8ce8 100644 --- a/src/ssort/_files.py +++ b/src/ssort/_files.py @@ -1,8 +1,14 @@ from pathlib import Path +__all__ = ["find_project_root"] + + +def current_working_dir(): + return Path(".").resolve() + def find_project_root(patterns): - all_patterns = [Path(".").resolve()] + all_patterns = [current_working_dir()] if patterns: all_patterns.extend(patterns) diff --git a/tests/test_files.py b/tests/test_files.py index de75039..dcbef2e 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1,68 +1,89 @@ -import pytest - -from ssort._files import find_project_root - -subdirs = [ - (".",), - (".", "dir"), - ("dir", "dirA/B"), - ("dirA/B", "dir"), - ("a/a/a", "b/b/b"), - ("dir/a", "dir/b"), -] - - -@pytest.fixture() -def git(tmp_path): - root = tmp_path / "root" - (root / ".git").mkdir(parents=True) - return root - - -@pytest.mark.parametrize("subdir", subdirs) -def test_find_project_root_git(subdir, git): - print(subdir) - patterns = [git / sub for sub in subdir] - - for p in patterns: - p.mkdir(parents=True, exist_ok=True) - - assert git == find_project_root(patterns) - - -@pytest.fixture() -def pyproject(tmp_path): - root = tmp_path / "root" - root.mkdir() - (root / "pyproject.toml").touch() - return root +from pathlib import Path +import pytest -@pytest.mark.parametrize("subdir", subdirs) -def test_find_project_root_pyproject(subdir, pyproject): - patterns = [pyproject / sub for sub in subdir] - - for p in patterns: - p.mkdir(parents=True, exist_ok=True) - - assert pyproject == find_project_root(patterns) - - -@pytest.fixture() -def neither(tmp_path): - root = tmp_path / "root" - root.mkdir() - return root - - -@pytest.mark.parametrize("subdir", subdirs) -def test_find_project_root_neither(subdir, neither): - patterns = [neither / sub for sub in subdir] - - if all(s.startswith("dir/") for s in subdir): - neither = neither / "dir" - - for p in patterns: - p.mkdir(parents=True, exist_ok=True) - - assert neither == find_project_root(patterns) +from ssort._files import current_working_dir, find_project_root + + +def test_current_working_dir(): + assert current_working_dir() == Path(".").resolve() + + +class TestFindProjectRoot: + @pytest.fixture( + params=[ + (".",), + (".", "dir"), + ("dir", "dirA/B"), + ("dirA/B", "dir"), + ("a/a/a", "b/b/b"), + ("dir/a", "dir/b"), + ] + ) + def subdir(self, request): + return request.param + + @pytest.fixture() + def mock_current_working_dir(self, mocker, tmp_path): + return mocker.patch( + "ssort._files.current_working_dir", return_value=tmp_path / "root" + ) + + @pytest.fixture() + def git(self, tmp_path): + root = tmp_path / "root" + (root / ".git").mkdir(parents=True) + return root + + def test_find_project_root_git( + self, subdir, git, mock_current_working_dir + ): + print(subdir) + patterns = [git / sub for sub in subdir] + + for p in patterns: + p.mkdir(parents=True, exist_ok=True) + + assert git == find_project_root(patterns) + mock_current_working_dir.assert_called_once() + + @pytest.fixture() + def pyproject(self, tmp_path): + root = tmp_path / "root" + root.mkdir() + (root / "pyproject.toml").touch() + return root + + def test_find_project_root_pyproject( + self, subdir, pyproject, mock_current_working_dir + ): + patterns = [pyproject / sub for sub in subdir] + + for p in patterns: + p.mkdir(parents=True, exist_ok=True) + + assert pyproject == find_project_root(patterns) + mock_current_working_dir.assert_called_once() + + @pytest.fixture() + def neither(self, tmp_path): + root = tmp_path / "root" + 
root.mkdir() + return root + + def test_find_project_root_neither( + self, subdir, neither, mocker, tmp_path + ): + mocker.patch( + "ssort._files.current_working_dir", + return_value=tmp_path / "root" / "dir", + ) + patterns = [neither / sub for sub in subdir] + + if all(s.startswith("dir/") for s in subdir): + neither = neither / "dir" + + for p in patterns: + p.mkdir(parents=True, exist_ok=True) + + assert neither == find_project_root(patterns) From 77e6705f3495415b133f5055ccc85eb35a1b6b92 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 09:28:29 +0200 Subject: [PATCH 41/61] remove __pycache__ and .pytest_cache to match black's defaults --- src/ssort/_config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ssort/_config.py b/src/ssort/_config.py index d68827f..f9921a0 100644 --- a/src/ssort/_config.py +++ b/src/ssort/_config.py @@ -23,13 +23,11 @@ ".mypy_cache", ".nox", ".pants.d", - ".pytest_cache", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", - "__pycache__", "__pypackages__", "_build", "buck-out", From 4440ec630d9c0d429f9ddbcaedcf8f20f8544404 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 09:30:45 +0200 Subject: [PATCH 42/61] ignore __pycache__ and .pytest_cache folders --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0c45183..d05790a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,4 +82,8 @@ attr = "ssort.__version__" where = ["src"] [tool.ssort] -extend_skip = ['test_data'] +extend_skip = [ + ".pytest_cache", + "__pycache__", + "test_data" +] From 7eb76c07b27a34aec0e9fe8c54a52b49274b2ad0 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:05:27 +0200 Subject: [PATCH 43/61] move sort_key_from_iter to _graphs.py --- src/ssort/_graphs.py | 7 +++++-- src/ssort/_ssort.py | 8 ++------ src/ssort/_utils.py | 6 ------ 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/ssort/_graphs.py b/src/ssort/_graphs.py index e9b8469..a00e552 100644 --- a/src/ssort/_graphs.py +++ b/src/ssort/_graphs.py @@ -2,8 +2,6 @@ from typing import Callable, Generic, Hashable, TypeVar -from ssort._utils import sort_key_from_iter - _T = TypeVar("_T", bound=Hashable) @@ -140,6 +138,11 @@ def is_topologically_sorted(nodes: list[_T], graph: Graph[_T]) -> bool: return True +def sort_key_from_iter(values): + index = {statement: index for index, statement in enumerate(values)} + return lambda value: index[value] + + def topological_sort( target: Graph[_T] | list[_T], /, *, graph: Graph[_T] | None = None ) -> list[_T]: diff --git a/src/ssort/_ssort.py b/src/ssort/_ssort.py index ce3282b..cf7bda0 100644 --- a/src/ssort/_ssort.py +++ b/src/ssort/_ssort.py @@ -17,15 +17,11 @@ from ssort._graphs import ( is_topologically_sorted, replace_cycles, + sort_key_from_iter, topological_sort, ) from ssort._parsing import parse, split_class -from ssort._utils import ( - detect_encoding, - detect_newline, - normalize_newlines, - sort_key_from_iter, -) +from ssort._utils import detect_encoding, detect_newline, normalize_newlines SPECIAL_PROPERTIES = [ "__doc__", diff --git a/src/ssort/_utils.py b/src/ssort/_utils.py index b98925b..4de7e60 100644 --- a/src/ssort/_utils.py +++ b/src/ssort/_utils.py @@ -16,12 +16,6 @@ memoize = functools.cache -def sort_key_from_iter(values): - index = {statement: index for index, statement in enumerate(values)} - key = lambda value: index[value] - return key - - _T = TypeVar("_T") From 014f6d1afe291677937aa95fe8f4f5e29db8b2cd Mon
Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:14:27 +0200 Subject: [PATCH 44/61] move single_dispatch to its own file --- src/ssort/_ast.py | 2 +- src/ssort/_bindings.py | 2 +- src/ssort/_method_requirements.py | 2 +- src/ssort/_requirements.py | 2 +- src/ssort/_single_dispatch.py | 33 +++++++++++++++++++++++++++++++ src/ssort/_utils.py | 26 +----------------------- 6 files changed, 38 insertions(+), 29 deletions(-) create mode 100644 src/ssort/_single_dispatch.py diff --git a/src/ssort/_ast.py b/src/ssort/_ast.py index 5288be8..9bd16cb 100644 --- a/src/ssort/_ast.py +++ b/src/ssort/_ast.py @@ -4,7 +4,7 @@ import sys from typing import Iterable -from ssort._utils import single_dispatch +from ssort._single_dispatch import single_dispatch @single_dispatch diff --git a/src/ssort/_bindings.py b/src/ssort/_bindings.py index c0476b3..27081ae 100644 --- a/src/ssort/_bindings.py +++ b/src/ssort/_bindings.py @@ -5,7 +5,7 @@ from typing import Iterable from ssort._ast import iter_child_nodes -from ssort._utils import single_dispatch +from ssort._single_dispatch import single_dispatch @single_dispatch diff --git a/src/ssort/_method_requirements.py b/src/ssort/_method_requirements.py index 73f72e6..d9f656a 100644 --- a/src/ssort/_method_requirements.py +++ b/src/ssort/_method_requirements.py @@ -4,7 +4,7 @@ from typing import Iterable from ssort._ast import iter_child_nodes -from ssort._utils import single_dispatch +from ssort._single_dispatch import single_dispatch @single_dispatch diff --git a/src/ssort/_requirements.py b/src/ssort/_requirements.py index 4700c43..89c2a53 100644 --- a/src/ssort/_requirements.py +++ b/src/ssort/_requirements.py @@ -8,7 +8,7 @@ from ssort._ast import iter_child_nodes from ssort._bindings import get_bindings from ssort._builtins import CLASS_BUILTINS -from ssort._utils import single_dispatch +from ssort._single_dispatch import single_dispatch class Scope(enum.Enum): diff --git a/src/ssort/_single_dispatch.py b/src/ssort/_single_dispatch.py new file mode 100644 index 0000000..3d2b989 --- /dev/null +++ b/src/ssort/_single_dispatch.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import functools +from typing import Any, Callable, Generic, TypeVar + +_T = TypeVar("_T") + + +__all__ = ["single_dispatch"] + + +class _SingleDispatch(Generic[_T]): + """A more performant implementation of functools.singledispatch.""" + + def __init__(self, function: Callable[..., _T]) -> None: + functools.update_wrapper(self, function) + self._function: Callable[..., _T] = function + self._functions: dict[type[Any], Callable[..., _T]] = {} + + def register( + self, cls: type[Any] + ) -> Callable[[Callable[..., _T]], Callable[..., _T]]: + def decorator(function: Callable[..., _T]) -> Callable[..., _T]: + self._functions[cls] = function + return function + + return decorator + + def __call__(self, arg: Any, *args: Any) -> _T: + return self._functions.get(type(arg), self._function)(arg, *args) + + +single_dispatch = _SingleDispatch diff --git a/src/ssort/_utils.py b/src/ssort/_utils.py index 4de7e60..fbba5c3 100644 --- a/src/ssort/_utils.py +++ b/src/ssort/_utils.py @@ -6,7 +6,7 @@ import shlex import sys import tokenize -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Callable, TypeVar from ssort._exceptions import UnknownEncodingError @@ -19,30 +19,6 @@ _T = TypeVar("_T") -class _SingleDispatch(Generic[_T]): - """A more performant implementation of functools.singledispatch.""" - - def __init__(self, function: Callable[..., _T]) -> 
None: - functools.update_wrapper(self, function) - self._function: Callable[..., _T] = function - self._functions: dict[type[Any], Callable[..., _T]] = {} - - def register( - self, cls: type[Any] - ) -> Callable[[Callable[..., _T]], Callable[..., _T]]: - def decorator(function: Callable[..., _T]) -> Callable[..., _T]: - self._functions[cls] = function - return function - - return decorator - - def __call__(self, arg: Any, *args: Any) -> _T: - return self._functions.get(type(arg), self._function)(arg, *args) - - -single_dispatch = _SingleDispatch - - def cached_method(function: Callable[[Any], _T]) -> Callable[[Any], _T]: cached_attribute_name = f"_{function.__name__}_cache" From e6a372a3d73a0a0daa2e23df290ae6c6b5aff146 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:23:51 +0200 Subject: [PATCH 45/61] move cached_method to _statements.py --- src/ssort/_statements.py | 23 +++++++++++++++++++++-- src/ssort/_utils.py | 17 +---------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/ssort/_statements.py b/src/ssort/_statements.py index 5ffdf52..3ce824e 100644 --- a/src/ssort/_statements.py +++ b/src/ssort/_statements.py @@ -1,12 +1,31 @@ from __future__ import annotations import ast -from typing import Iterable +import functools +from typing import Any, Callable, Iterable, TypeVar from ssort._bindings import get_bindings from ssort._method_requirements import get_method_requirements from ssort._requirements import Requirement, get_requirements -from ssort._utils import cached_method + +_T = TypeVar("_T") + +__all__ = ["Statement"] + + +def cached_method(function: Callable[[Any], _T]) -> Callable[[Any], _T]: + cached_attribute_name = f"_{function.__name__}_cache" + + @functools.wraps(function) + def wrapper(self) -> _T: + try: + return getattr(self, cached_attribute_name) + except AttributeError: + value = function(self) + setattr(self, cached_attribute_name, value) + return value + + return wrapper class Statement: diff --git a/src/ssort/_utils.py b/src/ssort/_utils.py index fbba5c3..690469a 100644 --- a/src/ssort/_utils.py +++ b/src/ssort/_utils.py @@ -6,7 +6,7 @@ import shlex import sys import tokenize -from typing import Any, Callable, TypeVar +from typing import TypeVar from ssort._exceptions import UnknownEncodingError @@ -19,21 +19,6 @@ _T = TypeVar("_T") -def cached_method(function: Callable[[Any], _T]) -> Callable[[Any], _T]: - cached_attribute_name = f"_{function.__name__}_cache" - - @functools.wraps(function) - def wrapper(self) -> _T: - try: - return getattr(self, cached_attribute_name) - except AttributeError: - value = function(self) - setattr(self, cached_attribute_name, value) - return value - - return wrapper - - def escape_path(path): """ Takes a `pathlib.Path` object and returns a string representation that can From 27ec33e2ab5eef926e0ff82679810e34a9dca0ff Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:31:23 +0200 Subject: [PATCH 46/61] move remaining functions to _files.py --- src/ssort/_files.py | 64 +++++++++++++++++++++++++++++++++++++- src/ssort/_main.py | 6 ++-- src/ssort/_ssort.py | 2 +- src/ssort/_utils.py | 66 ---------------------------------------- tests/test_executable.py | 2 +- 5 files changed, 68 insertions(+), 72 deletions(-) delete mode 100644 src/ssort/_utils.py diff --git a/src/ssort/_files.py b/src/ssort/_files.py index 13b8ce8..ac5efec 100644 --- a/src/ssort/_files.py +++ b/src/ssort/_files.py @@ -1,6 +1,24 @@ +from __future__ import annotations + +import io +import re +import shlex 
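For context on the helper being moved: `_SingleDispatch` looks handlers up by the exact `type(arg)`, with no MRO fallback, which is presumably what the "more performant" docstring refers to. An illustrative sketch of registering and calling such a dispatcher (the `describe` function is invented for the example, not part of the codebase):

    import ast

    from ssort._single_dispatch import single_dispatch

    @single_dispatch
    def describe(node):
        # Fallback for any node type without its own registration.
        return type(node).__name__

    @describe.register(ast.Name)
    def _describe_name(node):
        return f"Name({node.id!r})"

    tree = ast.parse("x = 1")
    print([describe(node) for node in ast.walk(tree)])
    # ['Module', 'Assign', "Name('x')", 'Constant', 'Store']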
+import sys +import tokenize from pathlib import Path -__all__ = ["find_project_root"] +from ssort._exceptions import UnknownEncodingError + +__all__ = [ + "detect_encoding", + "detect_newline", + "escape_path", + "find_project_root", + "normalize_newlines", +] + + +_NEWLINE_RE = re.compile("(\r\n)|(\r)|(\n)") def current_working_dir(): @@ -31,3 +49,47 @@ def find_project_root(patterns): return directory return common_base + + +def escape_path(path): + """ + Takes a `pathlib.Path` object and returns a string representation that can + be safely copied into the system shell. + """ + if sys.platform == "win32": + # TODO + return str(path) + else: + return shlex.quote(str(path)) + + +def detect_encoding(bytestring): + """ + Detect the encoding of a python source file based on "coding" comments, as + defined in [PEP 263](https://www.python.org/dev/peps/pep-0263/). + """ + try: + encoding, _ = tokenize.detect_encoding(io.BytesIO(bytestring).readline) + except SyntaxError as exc: + raise UnknownEncodingError( + exc.msg, encoding=re.match("unknown encoding: (.*)", exc.msg)[1] + ) from exc + return encoding + + +def detect_newline(text): + """ + Detects the newline character used in a source file based on the first + occurence of '\\n', '\\r' or '\\r\\n'. + """ + match = re.search(_NEWLINE_RE, text) + if match is None: + return "\n" + return match[0] + + +def normalize_newlines(text): + """ + Replaces all occurrences of '\r' and '\\r\\n' with \n. + """ + return re.sub(_NEWLINE_RE, "\n", text) diff --git a/src/ssort/_main.py b/src/ssort/_main.py index ee6cfd5..4fd9afb 100644 --- a/src/ssort/_main.py +++ b/src/ssort/_main.py @@ -5,14 +5,14 @@ from ssort._config import get_config_from_root from ssort._exceptions import UnknownEncodingError -from ssort._files import find_project_root -from ssort._ssort import ssort -from ssort._utils import ( +from ssort._files import ( detect_encoding, detect_newline, escape_path, + find_project_root, normalize_newlines, ) +from ssort._ssort import ssort def main(): diff --git a/src/ssort/_ssort.py b/src/ssort/_ssort.py index cf7bda0..374a8ea 100644 --- a/src/ssort/_ssort.py +++ b/src/ssort/_ssort.py @@ -14,6 +14,7 @@ UnknownEncodingError, WildcardImportError, ) +from ssort._files import detect_encoding, detect_newline, normalize_newlines from ssort._graphs import ( is_topologically_sorted, replace_cycles, @@ -21,7 +22,6 @@ topological_sort, ) from ssort._parsing import parse, split_class -from ssort._utils import detect_encoding, detect_newline, normalize_newlines SPECIAL_PROPERTIES = [ "__doc__", diff --git a/src/ssort/_utils.py b/src/ssort/_utils.py deleted file mode 100644 index 690469a..0000000 --- a/src/ssort/_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import functools -import io -import re -import shlex -import sys -import tokenize -from typing import TypeVar - -from ssort._exceptions import UnknownEncodingError - -if sys.version_info < (3, 9): - memoize = functools.lru_cache(maxsize=None) -else: - memoize = functools.cache - - -_T = TypeVar("_T") - - -def escape_path(path): - """ - Takes a `pathlib.Path` object and returns a string representation that can - be safely copied into the system shell. - """ - if sys.platform == "win32": - # TODO - return str(path) - else: - return shlex.quote(str(path)) - - -def detect_encoding(bytestring): - """ - Detect the encoding of a python source file based on "coding" comments, as - defined in [PEP 263](https://www.python.org/dev/peps/pep-0263/). 
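The helpers gathered into `_files.py` are thin wrappers over the standard library; once this patch is applied they can be used roughly like this (illustrative input, not from the test suite):

    from ssort._files import detect_encoding, detect_newline, normalize_newlines

    raw = b"# -*- coding: utf-8 -*-\r\nx = 1\r\n"
    encoding = detect_encoding(raw)    # "utf-8", read from the PEP 263 cookie
    text = raw.decode(encoding)
    newline = detect_newline(text)     # "\r\n", the first newline style found
    print(normalize_newlines(text))    # same text with every newline as "\n"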
- """ - try: - encoding, _ = tokenize.detect_encoding(io.BytesIO(bytestring).readline) - except SyntaxError as exc: - raise UnknownEncodingError( - exc.msg, encoding=re.match("unknown encoding: (.*)", exc.msg)[1] - ) from exc - return encoding - - -_NEWLINE_RE = re.compile("(\r\n)|(\r)|(\n)") - - -def detect_newline(text): - """ - Detects the newline character used in a source file based on the first - occurence of '\\n', '\\r' or '\\r\\n'. - """ - match = re.search(_NEWLINE_RE, text) - if match is None: - return "\n" - return match[0] - - -def normalize_newlines(text): - """ - Replaces all occurrences of '\r' and '\\r\\n' with \n. - """ - return re.sub(_NEWLINE_RE, "\n", text) diff --git a/tests/test_executable.py b/tests/test_executable.py index 83eac71..94a73f8 100644 --- a/tests/test_executable.py +++ b/tests/test_executable.py @@ -4,7 +4,7 @@ import pytest -from ssort._utils import escape_path +from ssort._files import escape_path _good = b""" def _private(): From d4be66c8bb23852a47af3d57f7ca5fb1a2f508ce Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:14:27 +0200 Subject: [PATCH 47/61] move single_dispatch to its own file --- src/ssort/_utils.py | 81 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 src/ssort/_utils.py diff --git a/src/ssort/_utils.py b/src/ssort/_utils.py new file mode 100644 index 0000000..fbba5c3 --- /dev/null +++ b/src/ssort/_utils.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import functools +import io +import re +import shlex +import sys +import tokenize +from typing import Any, Callable, TypeVar + +from ssort._exceptions import UnknownEncodingError + +if sys.version_info < (3, 9): + memoize = functools.lru_cache(maxsize=None) +else: + memoize = functools.cache + + +_T = TypeVar("_T") + + +def cached_method(function: Callable[[Any], _T]) -> Callable[[Any], _T]: + cached_attribute_name = f"_{function.__name__}_cache" + + @functools.wraps(function) + def wrapper(self) -> _T: + try: + return getattr(self, cached_attribute_name) + except AttributeError: + value = function(self) + setattr(self, cached_attribute_name, value) + return value + + return wrapper + + +def escape_path(path): + """ + Takes a `pathlib.Path` object and returns a string representation that can + be safely copied into the system shell. + """ + if sys.platform == "win32": + # TODO + return str(path) + else: + return shlex.quote(str(path)) + + +def detect_encoding(bytestring): + """ + Detect the encoding of a python source file based on "coding" comments, as + defined in [PEP 263](https://www.python.org/dev/peps/pep-0263/). + """ + try: + encoding, _ = tokenize.detect_encoding(io.BytesIO(bytestring).readline) + except SyntaxError as exc: + raise UnknownEncodingError( + exc.msg, encoding=re.match("unknown encoding: (.*)", exc.msg)[1] + ) from exc + return encoding + + +_NEWLINE_RE = re.compile("(\r\n)|(\r)|(\n)") + + +def detect_newline(text): + """ + Detects the newline character used in a source file based on the first + occurence of '\\n', '\\r' or '\\r\\n'. + """ + match = re.search(_NEWLINE_RE, text) + if match is None: + return "\n" + return match[0] + + +def normalize_newlines(text): + """ + Replaces all occurrences of '\r' and '\\r\\n' with \n. 
+ """ + return re.sub(_NEWLINE_RE, "\n", text) From a55658bfbb98bc2296e1e8895a03d0c2a54f7336 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:23:51 +0200 Subject: [PATCH 48/61] move cached_method to _statements.py --- src/ssort/_utils.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/ssort/_utils.py b/src/ssort/_utils.py index fbba5c3..690469a 100644 --- a/src/ssort/_utils.py +++ b/src/ssort/_utils.py @@ -6,7 +6,7 @@ import shlex import sys import tokenize -from typing import Any, Callable, TypeVar +from typing import TypeVar from ssort._exceptions import UnknownEncodingError @@ -19,21 +19,6 @@ _T = TypeVar("_T") -def cached_method(function: Callable[[Any], _T]) -> Callable[[Any], _T]: - cached_attribute_name = f"_{function.__name__}_cache" - - @functools.wraps(function) - def wrapper(self) -> _T: - try: - return getattr(self, cached_attribute_name) - except AttributeError: - value = function(self) - setattr(self, cached_attribute_name, value) - return value - - return wrapper - - def escape_path(path): """ Takes a `pathlib.Path` object and returns a string representation that can From 66e64c1939b20ad6b0dda3b09c8c404d4e9daaa6 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 17:31:23 +0200 Subject: [PATCH 49/61] move remaining functions to _files.py --- src/ssort/_utils.py | 66 --------------------------------------------- 1 file changed, 66 deletions(-) delete mode 100644 src/ssort/_utils.py diff --git a/src/ssort/_utils.py b/src/ssort/_utils.py deleted file mode 100644 index 690469a..0000000 --- a/src/ssort/_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import functools -import io -import re -import shlex -import sys -import tokenize -from typing import TypeVar - -from ssort._exceptions import UnknownEncodingError - -if sys.version_info < (3, 9): - memoize = functools.lru_cache(maxsize=None) -else: - memoize = functools.cache - - -_T = TypeVar("_T") - - -def escape_path(path): - """ - Takes a `pathlib.Path` object and returns a string representation that can - be safely copied into the system shell. - """ - if sys.platform == "win32": - # TODO - return str(path) - else: - return shlex.quote(str(path)) - - -def detect_encoding(bytestring): - """ - Detect the encoding of a python source file based on "coding" comments, as - defined in [PEP 263](https://www.python.org/dev/peps/pep-0263/). - """ - try: - encoding, _ = tokenize.detect_encoding(io.BytesIO(bytestring).readline) - except SyntaxError as exc: - raise UnknownEncodingError( - exc.msg, encoding=re.match("unknown encoding: (.*)", exc.msg)[1] - ) from exc - return encoding - - -_NEWLINE_RE = re.compile("(\r\n)|(\r)|(\n)") - - -def detect_newline(text): - """ - Detects the newline character used in a source file based on the first - occurence of '\\n', '\\r' or '\\r\\n'. - """ - match = re.search(_NEWLINE_RE, text) - if match is None: - return "\n" - return match[0] - - -def normalize_newlines(text): - """ - Replaces all occurrences of '\r' and '\\r\\n' with \n. 
- """ - return re.sub(_NEWLINE_RE, "\n", text) From c78111ab3e08e421be1192a806d9753e85c53f78 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 19:00:09 +0200 Subject: [PATCH 50/61] add new py311 exceptions --- src/ssort/_builtins.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ssort/_builtins.py b/src/ssort/_builtins.py index 52a0106..7a57a9f 100644 --- a/src/ssort/_builtins.py +++ b/src/ssort/_builtins.py @@ -17,6 +17,7 @@ "AssertionError", "AttributeError", "BaseException", + "BaseExceptionGroup", "BlockingIOError", "BrokenPipeError", "BufferError", @@ -32,6 +33,7 @@ "EncodingWarning", "EnvironmentError", "Exception", + "ExceptionGroup", "False", "FileExistsError", "FileNotFoundError", From 6a89beda7d8e8d5565ff8ad3f3d6836a53772add Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 19:00:42 +0200 Subject: [PATCH 51/61] sort exceptions alphabetically --- src/ssort/_builtins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ssort/_builtins.py b/src/ssort/_builtins.py index 7a57a9f..f3bdb54 100644 --- a/src/ssort/_builtins.py +++ b/src/ssort/_builtins.py @@ -28,10 +28,10 @@ "ConnectionRefusedError", "ConnectionResetError", "DeprecationWarning", - "EOFError", "Ellipsis", "EncodingWarning", "EnvironmentError", + "EOFError", "Exception", "ExceptionGroup", "False", @@ -40,15 +40,15 @@ "FloatingPointError", "FutureWarning", "GeneratorExit", - "IOError", "ImportError", "ImportWarning", "IndentationError", "IndexError", "InterruptedError", + "IOError", "IsADirectoryError", - "KeyError", "KeyboardInterrupt", + "KeyError", "LookupError", "MemoryError", "ModuleNotFoundError", From 6d546c57c09b28f691662b2ce48440023ad1ec26 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 19:01:20 +0200 Subject: [PATCH 52/61] handle new ExceptionGroups i.e. 
ast.TryStar --- src/ssort/_ast.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ssort/_ast.py b/src/ssort/_ast.py index 9bd16cb..feb211d 100644 --- a/src/ssort/_ast.py +++ b/src/ssort/_ast.py @@ -142,7 +142,8 @@ def _iter_child_nodes_of_raise(node: ast.Raise) -> Iterable[ast.AST]: @iter_child_nodes.register(ast.Try) -def _iter_child_nodes_of_try(node: ast.Try) -> Iterable[ast.AST]: +@iter_child_nodes.register(ast.TryStar) +def _iter_child_nodes_of_try(node: ast.Try | ast.TryStar) -> Iterable[ast.AST]: yield from node.body yield from node.handlers yield from node.orelse From 37cde4b3c8f34d4603e28e80a04aee0d9b4d13e4 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 19:02:18 +0200 Subject: [PATCH 53/61] add tests for python 3.11 --- .github/workflows/ci.yaml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 39a9110..6ac17c6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,7 +10,7 @@ jobs: name: "Unit Tests" strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] runs-on: ["ubuntu-22.04", "windows-2019", "macos-11"] runs-on: ${{ matrix.runs-on }} steps: diff --git a/tox.ini b/tox.ini index 1218d12..7ce57b9 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py38,py39,py310,black,isort,ssort,pyflakes,pylint,mypy +envlist = py38,py39,py310,py311,black,isort,ssort,pyflakes,pylint,mypy isolated_build = true [testenv] From 095847364fb6682013172180179bafd4f9575de9 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 19:15:06 +0200 Subject: [PATCH 54/61] add version guard for ast.TryStar --- src/ssort/_ast.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/ssort/_ast.py b/src/ssort/_ast.py index feb211d..039ecf1 100644 --- a/src/ssort/_ast.py +++ b/src/ssort/_ast.py @@ -142,8 +142,7 @@ def _iter_child_nodes_of_raise(node: ast.Raise) -> Iterable[ast.AST]: @iter_child_nodes.register(ast.Try) -@iter_child_nodes.register(ast.TryStar) -def _iter_child_nodes_of_try(node: ast.Try | ast.TryStar) -> Iterable[ast.AST]: +def _iter_child_nodes_of_try(node: ast.Try) -> Iterable[ast.AST]: yield from node.body yield from node.handlers yield from node.orelse @@ -482,6 +481,16 @@ def _iter_child_nodes_of_match_or(node: ast.MatchOr) -> Iterable[ast.AST]: yield from node.patterns +if sys.version_info >= (3, 11): + + @iter_child_nodes.register(ast.TryStar) + def _iter_child_nodes_of_try(node: ast.TryStar) -> Iterable[ast.AST]: + yield from node.body + yield from node.handlers + yield from node.orelse + yield from node.finalbody + + @iter_child_nodes.register(ast.TypeIgnore) def _iter_child_nodes_of_type_ignore( node: ast.TypeIgnore, From f8d7e362b916f005774f78a24337bb1686139f40 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 19:20:47 +0200 Subject: [PATCH 55/61] add test for exception groups --- tests/test_requirements.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_requirements.py b/tests/test_requirements.py index 378ab8f..29bb044 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -11,6 +11,11 @@ reason="match statements were introduced in python 3.10", ) +exception_group = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="exception groups were introduced in python 3.11", +) + def _parse(source): source = 
textwrap.dedent(source) @@ -1157,3 +1162,36 @@ def test_match_statement_requirements_as(): """ ) assert _dep_names(node) == ["a"] + + +@exception_group +def test_try_star_requirements(): + """ + ..code:: python + + TryStar( + stmt* body, + excepthandler* handlers, + stmt* orelse, + stmt* finalbody, + ) + """ + node = _parse( + """ + try: + a = something_stupid() + except* ExceptionGroup as exc: + b = recover() + else: + c = otherwise() + finally: + d = finish() + """ + ) + assert _dep_names(node) == [ + "something_stupid", + "ExceptionGroup", + "recover", + "otherwise", + "finish", + ] From 3cc1ab90678d244dab2c85eeef5d1f51e117b263 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Sun, 30 Apr 2023 19:24:43 +0200 Subject: [PATCH 56/61] rename ast iter_child function --- src/ssort/_ast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ssort/_ast.py b/src/ssort/_ast.py index 039ecf1..2a048c9 100644 --- a/src/ssort/_ast.py +++ b/src/ssort/_ast.py @@ -484,7 +484,7 @@ def _iter_child_nodes_of_match_or(node: ast.MatchOr) -> Iterable[ast.AST]: if sys.version_info >= (3, 11): @iter_child_nodes.register(ast.TryStar) - def _iter_child_nodes_of_try(node: ast.TryStar) -> Iterable[ast.AST]: + def _iter_child_nodes_of_trystar(node: ast.TryStar) -> Iterable[ast.AST]: yield from node.body yield from node.handlers yield from node.orelse From 8137fa6bf3556f21d385d3daa25d6e1bd4de21d8 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 00:30:52 +0200 Subject: [PATCH 57/61] replace get_bindings with ast.NodeVisitor implementation --- src/ssort/_bindings.py | 166 +++++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 90 deletions(-) diff --git a/src/ssort/_bindings.py b/src/ssort/_bindings.py index 27081ae..fb975f1 100644 --- a/src/ssort/_bindings.py +++ b/src/ssort/_bindings.py @@ -1,111 +1,97 @@ from __future__ import annotations import ast -import sys -from typing import Iterable - -from ssort._ast import iter_child_nodes -from ssort._single_dispatch import single_dispatch - - -@single_dispatch -def get_bindings(node: ast.AST) -> Iterable[str]: - for child in iter_child_nodes(node): - yield from get_bindings(child) - - -@get_bindings.register(ast.FunctionDef) -@get_bindings.register(ast.AsyncFunctionDef) -def _get_bindings_for_function_def( - node: ast.FunctionDef | ast.AsyncFunctionDef, -) -> Iterable[str]: - for decorator in node.decorator_list: - yield from get_bindings(decorator) - yield node.name - yield from get_bindings(node.args) - if node.returns is not None: - yield from get_bindings(node.returns) - - -@get_bindings.register(ast.ClassDef) -def _get_bindings_for_class_def(node: ast.ClassDef) -> Iterable[str]: - for decorator in node.decorator_list: - yield from get_bindings(decorator) - for base in node.bases: - yield from get_bindings(base) - for keyword in node.keywords: - yield from get_bindings(keyword.value) - yield node.name - - -@get_bindings.register(ast.Import) -def _get_bindings_for_import(node: ast.Import) -> Iterable[str]: - for name in node.names: - if name.asname: - yield name.asname +from typing import Sequence + +__all__ = ["get_bindings"] + + +class Bindings(ast.NodeVisitor): + def __init__(self): + self.stack = [] + + def append(self, name: list[str] | str | None): + if name is None: + return + + if isinstance(name, list): + self.stack.extend(name) + else: - root, *rest = name.name.split(".", 1) - yield root + self.stack.append(name) + def flexible_visit(self, node: Sequence[ast.AST] | ast.AST | None): + if node is None: 
+ return + + if isinstance(node, Sequence): + for n in node: + self.flexible_visit(n) + else: + self.visit(node) -@get_bindings.register(ast.ImportFrom) -def _get_bindings_for_import_from(node: ast.ImportFrom) -> Iterable[str]: - for name in node.names: - yield name.asname if name.asname else name.name + def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef): + self.flexible_visit(node.decorator_list) + self.append(node.name) + # body is missing + self.flexible_visit(node.args) + self.flexible_visit(node.returns) + visit_AsyncFunctionDef = visit_FunctionDef -@get_bindings.register(ast.Global) -def _get_bindings_for_global(node: ast.Global) -> Iterable[str]: - yield from node.names + def visit_ClassDef(self, node: ast.ClassDef): + self.flexible_visit(node.decorator_list) + self.flexible_visit(node.bases) + self.flexible_visit(node.keywords) + self.append(node.name) + # missing keywords, starargs, body + def visit_Import(self, node): + for name in node.names: + if name.asname: + self.append(name.asname) + else: + root, *rest = name.name.split(".", 1) + self.append(root) -@get_bindings.register(ast.Nonlocal) -def _get_bindings_for_nonlocal(node: ast.Nonlocal) -> Iterable[str]: - yield from node.names + def visit_ImportFrom(self, node): + for name in node.names: + self.append(name.asname if name.asname else name.name) + def visit_Global(self, node: ast.Global | ast.Nonlocal): + self.append(node.names) -@get_bindings.register(ast.Lambda) -def _get_bindings_for_lambda(node: ast.Lambda) -> Iterable[str]: - yield from get_bindings(node.args) + visit_Nonlocal = visit_Global + def visit_Lambda(self, node): + self.flexible_visit(node.args) -@get_bindings.register(ast.Name) -def _get_bindings_for_name(node: ast.Name) -> Iterable[str]: - if isinstance(node.ctx, ast.Store): - yield node.id + # def visit_alias(self, node: ast.alias): + # self.append(node.asname if node.asname else node.name) + def visit_Name(self, node: ast.Name): + if isinstance(node.ctx, ast.Store): + self.append(node.id) -@get_bindings.register(ast.ExceptHandler) -def _get_bindings_for_except_handler(node: ast.ExceptHandler) -> Iterable[str]: - if node.type: - yield from get_bindings(node.type) - if node.name: - yield node.name - for statement in node.body: - yield from get_bindings(statement) + def visit_ExceptHandler(self, node: ast.ExceptHandler): + self.flexible_visit(node.type) + self.append(node.name) + self.flexible_visit(node.body) + def visit_MatchStar(self, node: ast.MatchStar): + self.append(node.name) -if sys.version_info >= (3, 10): + def visit_MatchMapping(self, node: ast.MatchMapping): + self.flexible_visit(node.keys) + self.flexible_visit(node.patterns) + self.append(node.rest) - @get_bindings.register(ast.MatchStar) - def _get_bindings_for_match_star(node: ast.MatchStar) -> Iterable[str]: - if node.name is not None: - yield node.name + def visit_MatchAs(self, node: ast.MatchAs): + self.flexible_visit(node.pattern) + self.append(node.name) - @get_bindings.register(ast.MatchMapping) - def _get_bindings_for_match_mapping( - node: ast.MatchMapping, - ) -> Iterable[str]: - for key in node.keys: - yield from get_bindings(key) - for pattern in node.patterns: - yield from get_bindings(pattern) - if node.rest is not None: - yield node.rest - @get_bindings.register(ast.MatchAs) - def _get_bindings_for_match_as(node: ast.MatchAs) -> Iterable[str]: - if node.pattern is not None: - yield from get_bindings(node.pattern) - if node.name is not None: - yield node.name +def get_bindings(node: ast.AST): + bindings = 
Bindings() + bindings.visit(node) + yield from bindings.stack From 24789d98ff7d4f029ca3604420e04f81c2c6cd90 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 00:37:24 +0200 Subject: [PATCH 58/61] replace get_method_requirements with ast.NodeVisitor implementation --- src/ssort/_method_requirements.py | 65 +++++++++++++------------------ 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/src/ssort/_method_requirements.py b/src/ssort/_method_requirements.py index d9f656a..3094acc 100644 --- a/src/ssort/_method_requirements.py +++ b/src/ssort/_method_requirements.py @@ -1,53 +1,44 @@ from __future__ import annotations import ast -from typing import Iterable -from ssort._ast import iter_child_nodes -from ssort._single_dispatch import single_dispatch +__all__ = ["get_method_requirements"] -@single_dispatch -def _get_attribute_accesses(node: ast.AST, variable: str) -> Iterable[str]: - for child in iter_child_nodes(node): - yield from _get_attribute_accesses(child, variable) +class SelfAccesses(ast.NodeVisitor): + def __init__(self, variable: str): + self.stack: list[str] = [] + self.variable = variable + def visit_ClassDef(self, _: ast.ClassDef): + # TODO + return -@_get_attribute_accesses.register(ast.ClassDef) -def _get_attribute_accesses_for_class_def( - node: ast.ClassDef, variable: str -) -> Iterable[str]: - # TODO - return () + def visit_Attribute(self, node: ast.Attribute): + if not isinstance(node.value, ast.Name): + super().visit(node.value) + elif isinstance(node.ctx, ast.Load) and node.value.id == self.variable: + self.stack.append(node.attr) -@_get_attribute_accesses.register(ast.Attribute) -def _get_attribute_accesses_for_attribute( - node: ast.Attribute, variable: str -) -> Iterable[str]: - yield from _get_attribute_accesses(node.value, variable) - if ( - isinstance(node.ctx, ast.Load) - and isinstance(node.value, ast.Name) - and node.value.id == variable - ): - yield node.attr +class MethodRequirements(ast.NodeVisitor): + def __init__(self) -> None: + self.stack: list[str] = [] + def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef): + if not node.args.args: + return -@single_dispatch -def get_method_requirements(node: ast.AST) -> Iterable[str]: - return () + self_arg = node.args.args[0].arg + self_access = SelfAccesses(self_arg) + self_access.visit(node) + self.stack.extend(self_access.stack) -@get_method_requirements.register(ast.FunctionDef) -@get_method_requirements.register(ast.AsyncFunctionDef) -def _get_method_requirements_for_function_def( - node: ast.FunctionDef | ast.AsyncFunctionDef, -) -> Iterable[str]: - if not node.args.args: - return + visit_AsyncFunctionDef = visit_FunctionDef - self_arg = node.args.args[0].arg - for statement in node.body: - yield from _get_attribute_accesses(statement, self_arg) +def get_method_requirements(node: ast.AST): + method_requirements = MethodRequirements() + method_requirements.visit(node) + yield from method_requirements.stack From bd9bb7a7fcbcabdedac8f79b38316eb92bdbca4f Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 01:17:59 +0200 Subject: [PATCH 59/61] replace get_requirements with ast.NodeVisitor implementation --- src/ssort/_requirements.py | 300 ++++++++++++++++++------------------- 1 file changed, 150 insertions(+), 150 deletions(-) diff --git a/src/ssort/_requirements.py b/src/ssort/_requirements.py index 89c2a53..15a85ea 100644 --- a/src/ssort/_requirements.py +++ b/src/ssort/_requirements.py @@ -3,12 +3,12 @@ import ast import dataclasses import enum 
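The `ast.NodeVisitor` subclasses above rely on the standard dispatch rule: `visit()` calls `visit_<ClassName>` when such a method exists and otherwise falls back to `generic_visit()`, which recurses into the child nodes. A small sketch of what `get_bindings` therefore reports per statement (expected output inferred from the implementation above):

    import ast

    from ssort._bindings import get_bindings

    module = ast.parse("import os.path\nx, y = 1, 2\ndef f(a): ...\n")
    print([list(get_bindings(stmt)) for stmt in module.body])
    # [['os'], ['x', 'y'], ['f']] -- the root package, the assignment
    # targets and the function name are what each statement binds.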
-from typing import Iterable +from typing import Sequence -from ssort._ast import iter_child_nodes from ssort._bindings import get_bindings from ssort._builtins import CLASS_BUILTINS -from ssort._single_dispatch import single_dispatch + +__all__ = ["get_requirements", "Requirement"] class Scope(enum.Enum): @@ -26,13 +26,7 @@ class Requirement: scope: Scope = Scope.LOCAL -@single_dispatch -def get_requirements(node: ast.AST) -> Iterable[Requirement]: - for child in iter_child_nodes(node): - yield from get_requirements(child) - - -def _get_scope_from_arguments(args: ast.arguments) -> set[str]: +def get_scope_from_arguments(args: ast.arguments) -> set[str]: scope: set[str] = set() scope.update(arg.arg for arg in args.posonlyargs) scope.update(arg.arg for arg in args.args) # Arghhh. @@ -44,144 +38,150 @@ def _get_scope_from_arguments(args: ast.arguments) -> set[str]: return scope -@get_requirements.register(ast.FunctionDef) -@get_requirements.register(ast.AsyncFunctionDef) -def _get_requirements_for_function_def( - node: ast.FunctionDef | ast.AsyncFunctionDef, -) -> Iterable[Requirement]: - for decorator in node.decorator_list: - yield from get_requirements(decorator) - - yield from get_requirements(node.args) - - if node.returns is not None: - yield from get_requirements(node.returns) +def get_requirements(node: ast.AST): + requirements = Requirements() + requirements.visit(node) + yield from requirements.stack - scope = _get_scope_from_arguments(node.args) - - requirements = [] - for statement in node.body: - scope.update(get_bindings(statement)) - for requirement in get_requirements(statement): - if not requirement.deferred: - requirement = dataclasses.replace(requirement, deferred=True) - requirements.append(requirement) - - for requirement in requirements: - if requirement.scope == Scope.GLOBAL: - yield requirement - elif requirement.scope == Scope.NONLOCAL: - yield dataclasses.replace(requirement, scope=Scope.LOCAL) - elif requirement.name not in scope: - yield requirement - - -@get_requirements.register(ast.ClassDef) -def _get_requirements_for_class_def( - node: ast.ClassDef, -) -> Iterable[Requirement]: - for decorator in node.decorator_list: - yield from get_requirements(decorator) - - for base in node.bases: - yield from get_requirements(base) - - scope = set(CLASS_BUILTINS) - - for statement in node.body: - for stmt_dep in get_requirements(statement): - if stmt_dep.deferred or stmt_dep.name not in scope: - yield stmt_dep - - scope.update(get_bindings(statement)) - - -@get_requirements.register(ast.For) -@get_requirements.register(ast.AsyncFor) -def _get_requirements_for_for( - node: ast.For | ast.AsyncFor, -) -> Iterable[Requirement]: - bindings = set(get_bindings(node)) - - yield from get_requirements(node.target) - yield from get_requirements(node.iter) - - for stmt in node.body: - for requirement in get_requirements(stmt): - if requirement.name not in bindings: - yield requirement - - for stmt in node.orelse: - for requirement in get_requirements(stmt): - if requirement.name not in bindings: - yield requirement - - -@get_requirements.register(ast.With) -@get_requirements.register(ast.AsyncWith) -def _get_requirements_for_with( - node: ast.With | ast.AsyncWith, -) -> Iterable[Requirement]: - bindings = set(get_bindings(node)) - - for item in node.items: - yield from get_requirements(item) - - for stmt in node.body: - for requirement in get_requirements(stmt): - if requirement.name not in bindings: - yield requirement - - -@get_requirements.register(ast.Global) -def 
_get_requirements_for_global(node: ast.Global) -> Iterable[Requirement]: - for name in node.names: - yield Requirement( - name=name, - lineno=node.lineno, - col_offset=node.col_offset, - scope=Scope.GLOBAL, - ) - - -@get_requirements.register(ast.Nonlocal) -def _get_requirements_for_nonlocal( - node: ast.Nonlocal, -) -> Iterable[Requirement]: - for name in node.names: - yield Requirement( - name=name, - lineno=node.lineno, - col_offset=node.col_offset, - scope=Scope.NONLOCAL, - ) - - -@get_requirements.register(ast.Lambda) -def _get_requirements_for_lambda(node: ast.Lambda) -> Iterable[Requirement]: - yield from get_requirements(node.args) - scope = _get_scope_from_arguments(node.args) - scope.update(get_bindings(node.body)) - for requirement in get_requirements(node.body): - if requirement.name not in scope: - yield requirement - - -@get_requirements.register(ast.ListComp) -@get_requirements.register(ast.SetComp) -@get_requirements.register(ast.DictComp) -@get_requirements.register(ast.GeneratorExp) -def _get_requirements_for_comp(node: ast.AST) -> Iterable[Requirement]: - bindings = set(get_bindings(node)) - for child in iter_child_nodes(node): - for requirement in get_requirements(child): - if requirement.name not in bindings: - yield requirement - - -@get_requirements.register(ast.Name) -def _get_requirements_for_name(node: ast.Name) -> Iterable[Requirement]: - if isinstance(node.ctx, (ast.Load, ast.Del)): - yield Requirement( - name=node.id, lineno=node.lineno, col_offset=node.col_offset - ) + +class Requirements(ast.NodeVisitor): + def __init__(self): + self.stack = [] + + def flexible_visit(self, node: Sequence[ast.AST] | ast.AST | None): + if node is None: + return + + if isinstance(node, Sequence): + for n in node: + self.flexible_visit(n) + else: + self.visit(node) + + def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef): + self.flexible_visit(node.decorator_list) + self.flexible_visit(node.args) + self.flexible_visit(node.returns) + + scope = get_scope_from_arguments(node.args) + + requirements = [] + for statement in node.body: + scope.update(get_bindings(statement)) + for requirement in get_requirements(statement): + if not requirement.deferred: + requirement = dataclasses.replace( + requirement, deferred=True + ) + requirements.append(requirement) + + for requirement in requirements: + if requirement.scope == Scope.GLOBAL: + self.stack.append(requirement) + elif requirement.scope == Scope.NONLOCAL: + self.stack.append( + dataclasses.replace(requirement, scope=Scope.LOCAL) + ) + elif requirement.name not in scope: + self.stack.append(requirement) + + visit_AsyncFunctionDef = visit_FunctionDef + + def visit_ClassDef(self, node: ast.ClassDef): + self.flexible_visit(node.decorator_list) + self.flexible_visit(node.bases) + + scope = set(CLASS_BUILTINS) + + for statement in node.body: + for stmt_dep in get_requirements(statement): + if stmt_dep.deferred or stmt_dep.name not in scope: + self.stack.append(stmt_dep) + + scope.update(get_bindings(statement)) + + def visit_For(self, node: ast.For | ast.AsyncFor): + bindings = set(get_bindings(node)) + + self.flexible_visit(node.target) + self.flexible_visit(node.iter) + + for stmt in node.body: + for requirement in get_requirements(stmt): + if requirement.name not in bindings: + self.stack.append(requirement) + + for stmt in node.orelse: + for requirement in get_requirements(stmt): + if requirement.name not in bindings: + self.stack.append(requirement) + + visit_AsyncFor = visit_For + + def visit_With(self, node: 
ast.With | ast.AsyncWith): + bindings = set(get_bindings(node)) + + self.flexible_visit(node.items) + + for stmt in node.body: + for requirement in get_requirements(stmt): + if requirement.name not in bindings: + self.stack.append(requirement) + + visit_AsyncWith = visit_With + + def visit_Global(self, node: ast.Global): + for name in node.names: + self.stack.append( + Requirement( + name=name, + lineno=node.lineno, + col_offset=node.col_offset, + scope=Scope.GLOBAL, + ) + ) + + def visit_Nonlocal(self, node: ast.Nonlocal): + for name in node.names: + self.stack.append( + Requirement( + name=name, + lineno=node.lineno, + col_offset=node.col_offset, + scope=Scope.NONLOCAL, + ) + ) + + def visit_Lambda(self, node: ast.Lambda): + self.flexible_visit(node.args) + + scope = get_scope_from_arguments(node.args) + scope.update(get_bindings(node.body)) + + for requirement in get_requirements(node.body): + if requirement.name not in scope: + self.stack.append(requirement) + + def visit_ListComp( + self, + node: ast.ListComp | ast.SetComp | ast.DictComp | ast.GeneratorExp, + ): + bindings = set(get_bindings(node)) + for child in ast.iter_child_nodes(node): + for requirement in get_requirements(child): + if requirement.name not in bindings: + self.stack.append(requirement) + + visit_SetComp = visit_ListComp + visit_DictComp = visit_ListComp + visit_GeneratorExp = visit_ListComp + + def visit_Name(self, node: ast.Name): + if isinstance(node.ctx, (ast.Load, ast.Del)): + self.stack.append( + Requirement( + name=node.id, + lineno=node.lineno, + col_offset=node.col_offset, + ) + ) From ec26adfdf87ed305248e6015b2ef766b12cd7953 Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 01:19:06 +0200 Subject: [PATCH 60/61] remove redundant modules --- src/ssort/_ast.py | 498 ---------------------------------- src/ssort/_single_dispatch.py | 33 --- tests/test_ast.py | 73 ----- 3 files changed, 604 deletions(-) delete mode 100644 src/ssort/_ast.py delete mode 100644 src/ssort/_single_dispatch.py delete mode 100644 tests/test_ast.py diff --git a/src/ssort/_ast.py b/src/ssort/_ast.py deleted file mode 100644 index 2a048c9..0000000 --- a/src/ssort/_ast.py +++ /dev/null @@ -1,498 +0,0 @@ -from __future__ import annotations - -import ast -import sys -from typing import Iterable - -from ssort._single_dispatch import single_dispatch - - -@single_dispatch -def iter_child_nodes(node: ast.AST) -> Iterable[ast.AST]: - raise NotImplementedError( - f"AST traversal for {type(node).__name__!r} is not implemented" - ) - - -@iter_child_nodes.register(ast.Module) -def _iter_child_nodes_of_module(node: ast.Module) -> Iterable[ast.AST]: - yield from node.body - yield from node.type_ignores - - -@iter_child_nodes.register(ast.Interactive) -def _iter_child_nodes_of_interactive( - node: ast.Interactive, -) -> Iterable[ast.AST]: - yield from node.body - - -@iter_child_nodes.register(ast.Expression) -def _iter_child_nodes_of_expression(node: ast.Expression) -> Iterable[ast.AST]: - yield node.body - - -@iter_child_nodes.register(ast.FunctionType) -def _iter_child_nodes_of_function_type( - node: ast.FunctionType, -) -> Iterable[ast.AST]: - yield from node.argtypes - yield node.returns - - -@iter_child_nodes.register(ast.FunctionDef) -@iter_child_nodes.register(ast.AsyncFunctionDef) -def _iter_child_nodes_of_function_def( - node: ast.FunctionDef | ast.AsyncFunctionDef, -) -> Iterable[ast.AST]: - yield from node.decorator_list - yield node.args - if node.returns is not None: - yield node.returns - yield from node.body - - 
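The rewritten `Requirements` visitor keeps the contract of the dispatch-based version it replaces: every name a statement needs from the enclosing scope comes back as a `Requirement`, and names that are only needed once a function body actually runs are marked as deferred. A rough sketch (expected output inferred from the code above, not from the test suite):

    import ast

    from ssort._requirements import get_requirements

    stmt = ast.parse("def f(x):\n    return x + y\n").body[0]
    print([(req.name, req.deferred) for req in get_requirements(stmt)])
    # [('y', True)] -- 'x' is satisfied by the argument list, while 'y'
    # is an outside name that only matters once f() is called.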
-@iter_child_nodes.register(ast.ClassDef) -def _iter_child_nodes_of_class_def(node: ast.ClassDef) -> Iterable[ast.AST]: - yield from node.decorator_list - yield from node.bases - yield from node.keywords - yield from node.body - - -@iter_child_nodes.register(ast.Return) -def _iter_child_nodes_of_return(node: ast.Return) -> Iterable[ast.AST]: - if node.value is not None: - yield node.value - - -@iter_child_nodes.register(ast.Delete) -def _iter_child_nodes_of_delete(node: ast.Delete) -> Iterable[ast.AST]: - yield from node.targets - - -@iter_child_nodes.register(ast.Assign) -def _iter_child_nodes_of_assign(node: ast.Assign) -> Iterable[ast.AST]: - yield from node.targets - yield node.value - - -@iter_child_nodes.register(ast.AugAssign) -def _iter_child_nodes_of_aug_assign(node: ast.AugAssign) -> Iterable[ast.AST]: - yield node.target - yield node.value - - -@iter_child_nodes.register(ast.AnnAssign) -def _iter_child_nodes_of_ann_assign(node: ast.AnnAssign) -> Iterable[ast.AST]: - yield node.target - yield node.annotation - if node.value is not None: - yield node.value - - -@iter_child_nodes.register(ast.For) -@iter_child_nodes.register(ast.AsyncFor) -def _iter_child_nodes_of_for( - node: ast.For | ast.AsyncFor, -) -> Iterable[ast.AST]: - yield node.target - yield node.iter - yield from node.body - yield from node.orelse - - -@iter_child_nodes.register(ast.While) -def _iter_child_nodes_of_while(node: ast.While) -> Iterable[ast.AST]: - yield node.test - yield from node.body - yield from node.orelse - - -@iter_child_nodes.register(ast.If) -def _iter_child_nodes_of_if(node: ast.If) -> Iterable[ast.AST]: - yield node.test - yield from node.body - yield from node.orelse - - -@iter_child_nodes.register(ast.With) -@iter_child_nodes.register(ast.AsyncWith) -def _iter_child_nodes_of_with( - node: ast.With | ast.AsyncWith, -) -> Iterable[ast.AST]: - yield from node.items - yield from node.body - - -if sys.version_info >= (3, 10): - - @iter_child_nodes.register(ast.Match) - def _iter_child_nodes_of_match(node: ast.Match) -> Iterable[ast.AST]: - yield node.subject - yield from node.cases - - -@iter_child_nodes.register(ast.Raise) -def _iter_child_nodes_of_raise(node: ast.Raise) -> Iterable[ast.AST]: - if node.exc is not None: - yield node.exc - if node.cause is not None: - yield node.cause - - -@iter_child_nodes.register(ast.Try) -def _iter_child_nodes_of_try(node: ast.Try) -> Iterable[ast.AST]: - yield from node.body - yield from node.handlers - yield from node.orelse - yield from node.finalbody - - -@iter_child_nodes.register(ast.Assert) -def _iter_child_nodes_of_assert(node: ast.Assert) -> Iterable[ast.AST]: - yield node.test - if node.msg is not None: - yield node.msg - - -@iter_child_nodes.register(ast.Import) -def _iter_child_nodes_of_import(node: ast.Import) -> Iterable[ast.AST]: - yield from node.names - - -@iter_child_nodes.register(ast.ImportFrom) -def _iter_child_nodes_of_import_from( - node: ast.ImportFrom, -) -> Iterable[ast.AST]: - yield from node.names - - -@iter_child_nodes.register(ast.Global) -@iter_child_nodes.register(ast.Nonlocal) -def _iter_child_nodes_of_scope( - node: ast.Global | ast.Nonlocal, -) -> Iterable[ast.AST]: - return () - - -@iter_child_nodes.register(ast.Expr) -def _iter_child_nodes_of_expr(node: ast.Expr) -> Iterable[ast.AST]: - yield node.value - - -@iter_child_nodes.register(ast.Pass) -@iter_child_nodes.register(ast.Break) -@iter_child_nodes.register(ast.Continue) -def _iter_child_nodes_of_control_flow( - node: ast.Pass | ast.Break | ast.Continue, -) -> 
Iterable[ast.AST]: - return () - - -@iter_child_nodes.register(ast.BoolOp) -def _iter_child_nodes_of_bool_op(node: ast.BoolOp) -> Iterable[ast.AST]: - yield from node.values - - -@iter_child_nodes.register(ast.NamedExpr) -def _iter_child_nodes_of_named_expr(node: ast.NamedExpr) -> Iterable[ast.AST]: - yield node.target - yield node.value - - -@iter_child_nodes.register(ast.BinOp) -def _iter_child_nodes_of_bin_op(node: ast.BinOp) -> Iterable[ast.AST]: - yield node.left - yield node.right - - -@iter_child_nodes.register(ast.UnaryOp) -def _iter_child_nodes_of_unary_op(node: ast.UnaryOp) -> Iterable[ast.AST]: - yield node.operand - - -@iter_child_nodes.register(ast.Lambda) -def _iter_child_nodes_of_lambda(node: ast.Lambda) -> Iterable[ast.AST]: - yield node.args - yield node.body - - -@iter_child_nodes.register(ast.IfExp) -def _iter_child_nodes_of_if_exp(node: ast.IfExp) -> Iterable[ast.AST]: - yield node.test - yield node.body - yield node.orelse - - -@iter_child_nodes.register(ast.Dict) -def _iter_child_nodes_of_dict(node: ast.Dict) -> Iterable[ast.AST]: - for key in node.keys: - if key is not None: - yield key - yield from node.values - - -@iter_child_nodes.register(ast.Set) -def _iter_child_nodes_of_set(node: ast.Set) -> Iterable[ast.AST]: - yield from node.elts - - -@iter_child_nodes.register(ast.ListComp) -def _iter_child_nodes_of_list_comp(node: ast.ListComp) -> Iterable[ast.AST]: - yield node.elt - yield from node.generators - - -@iter_child_nodes.register(ast.SetComp) -def _iter_child_nodes_of_set_comp(node: ast.SetComp) -> Iterable[ast.AST]: - yield node.elt - yield from node.generators - - -@iter_child_nodes.register(ast.DictComp) -def _iter_child_nodes_of_dict_comp(node: ast.DictComp) -> Iterable[ast.AST]: - yield node.key - yield node.value - yield from node.generators - - -@iter_child_nodes.register(ast.GeneratorExp) -def _iter_child_nodes_of_generator_exp( - node: ast.GeneratorExp, -) -> Iterable[ast.AST]: - yield node.elt - yield from node.generators - - -@iter_child_nodes.register(ast.Await) -def _iter_child_nodes_of_await(node: ast.Await) -> Iterable[ast.AST]: - yield node.value - - -@iter_child_nodes.register(ast.Yield) -def _iter_child_nodes_of_yield(node: ast.Yield) -> Iterable[ast.AST]: - if node.value is not None: - yield node.value - - -@iter_child_nodes.register(ast.YieldFrom) -def _iter_child_nodes_of_yield_from(node: ast.YieldFrom) -> Iterable[ast.AST]: - yield node.value - - -@iter_child_nodes.register(ast.Compare) -def _iter_child_nodes_of_compare(node: ast.Compare) -> Iterable[ast.AST]: - yield node.left - yield from node.comparators - - -@iter_child_nodes.register(ast.Call) -def _iter_child_nodes_of_call(node: ast.Call) -> Iterable[ast.AST]: - yield node.func - yield from node.args - yield from node.keywords - - -@iter_child_nodes.register(ast.FormattedValue) -def _iter_child_nodes_of_formatted_value( - node: ast.FormattedValue, -) -> Iterable[ast.AST]: - yield node.value - if node.format_spec is not None: - yield node.format_spec - - -@iter_child_nodes.register(ast.JoinedStr) -def _iter_child_nodes_of_joined_str(node: ast.JoinedStr) -> Iterable[ast.AST]: - yield from node.values - - -@iter_child_nodes.register(ast.Constant) -def _iter_child_nodes_of_constant(node: ast.Constant) -> Iterable[ast.AST]: - return () - - -@iter_child_nodes.register(ast.Attribute) -def _iter_child_nodes_of_attribute(node: ast.Attribute) -> Iterable[ast.AST]: - yield node.value - - -@iter_child_nodes.register(ast.Subscript) -def _iter_child_nodes_of_subscript(node: ast.Subscript) -> 
Iterable[ast.AST]: - yield node.value - yield node.slice - - -@iter_child_nodes.register(ast.Starred) -def _iter_child_nodes_of_starred(node: ast.Starred) -> Iterable[ast.AST]: - yield node.value - - -@iter_child_nodes.register(ast.Name) -def _iter_child_nodes_of_name(node: ast.Name) -> Iterable[ast.AST]: - return () - - -@iter_child_nodes.register(ast.List) -@iter_child_nodes.register(ast.Tuple) -def _iter_child_nodes_of_sequence( - node: ast.List | ast.Tuple, -) -> Iterable[ast.AST]: - yield from node.elts - - -@iter_child_nodes.register(ast.Slice) -def _iter_child_nodes_of_slice(node: ast.Slice) -> Iterable[ast.AST]: - if node.lower is not None: - yield node.lower - if node.upper is not None: - yield node.upper - if node.step is not None: - yield node.step - - -if sys.version_info < (3, 9): - - @iter_child_nodes.register(ast.ExtSlice) - def _iter_child_nodes_of_ext_slice( - node: ast.ExtSlice, - ) -> Iterable[ast.AST]: - yield from node.dims - - @iter_child_nodes.register(ast.Index) - def _iter_child_nodes_of_index(node: ast.Index) -> Iterable[ast.AST]: - yield node.value - - -@iter_child_nodes.register(ast.comprehension) -def _iter_child_nodes_of_comprehension( - node: ast.comprehension, -) -> Iterable[ast.AST]: - yield node.target - yield node.iter - yield from node.ifs - - -@iter_child_nodes.register(ast.ExceptHandler) -def _iter_child_nodes_of_except_handler( - node: ast.ExceptHandler, -) -> Iterable[ast.AST]: - if node.type is not None: - yield node.type - yield from node.body - - -@iter_child_nodes.register(ast.arguments) -def _iter_child_nodes_of_arguments(node: ast.arguments) -> Iterable[ast.AST]: - yield from node.posonlyargs - yield from node.args - if node.vararg is not None: - yield node.vararg - yield from node.kwonlyargs - for default in node.kw_defaults: - if default is not None: - yield default - if node.kwarg is not None: - yield node.kwarg - yield from node.defaults - - -@iter_child_nodes.register(ast.arg) -def _iter_child_nodes_of_arg(node: ast.arg) -> Iterable[ast.AST]: - if node.annotation is not None: - yield node.annotation - - -@iter_child_nodes.register(ast.keyword) -def _iter_child_nodes_of_keyword(node: ast.keyword) -> Iterable[ast.AST]: - yield node.value - - -@iter_child_nodes.register(ast.alias) -def _iter_child_nodes_of_alias(node: ast.alias) -> Iterable[ast.AST]: - return () - - -@iter_child_nodes.register(ast.withitem) -def _iter_child_nodes_of_withitem(node: ast.withitem) -> Iterable[ast.AST]: - yield node.context_expr - if node.optional_vars is not None: - yield node.optional_vars - - -if sys.version_info >= (3, 10): - - @iter_child_nodes.register(ast.match_case) - def _iter_child_nodes_of_match_case( - node: ast.match_case, - ) -> Iterable[ast.AST]: - yield node.pattern - if node.guard is not None: - yield node.guard - yield from node.body - - @iter_child_nodes.register(ast.MatchValue) - def _iter_child_nodes_of_match_value( - node: ast.MatchValue, - ) -> Iterable[ast.AST]: - yield node.value - - @iter_child_nodes.register(ast.MatchSingleton) - def _iter_child_nodes_of_match_singleton( - node: ast.MatchSingleton, - ) -> Iterable[ast.AST]: - return () - - @iter_child_nodes.register(ast.MatchSequence) - def _iter_child_nodes_of_match_sequence( - node: ast.MatchSequence, - ) -> Iterable[ast.AST]: - yield from node.patterns - - @iter_child_nodes.register(ast.MatchMapping) - def _iter_child_nodes_of_match_mapping( - node: ast.MatchMapping, - ) -> Iterable[ast.AST]: - yield from node.keys - yield from node.patterns - - 
@iter_child_nodes.register(ast.MatchClass) - def _iter_child_nodes_of_match_class( - node: ast.MatchClass, - ) -> Iterable[ast.AST]: - yield node.cls - yield from node.patterns - yield from node.kwd_patterns - - @iter_child_nodes.register(ast.MatchStar) - def _iter_child_nodes_of_match_star( - node: ast.MatchStar, - ) -> Iterable[ast.AST]: - return () - - @iter_child_nodes.register(ast.MatchAs) - def _iter_child_nodes_of_match_as(node: ast.MatchAs) -> Iterable[ast.AST]: - if node.pattern is not None: - yield node.pattern - - @iter_child_nodes.register(ast.MatchOr) - def _iter_child_nodes_of_match_or(node: ast.MatchOr) -> Iterable[ast.AST]: - yield from node.patterns - - -if sys.version_info >= (3, 11): - - @iter_child_nodes.register(ast.TryStar) - def _iter_child_nodes_of_trystar(node: ast.TryStar) -> Iterable[ast.AST]: - yield from node.body - yield from node.handlers - yield from node.orelse - yield from node.finalbody - - -@iter_child_nodes.register(ast.TypeIgnore) -def _iter_child_nodes_of_type_ignore( - node: ast.TypeIgnore, -) -> Iterable[ast.AST]: - return () diff --git a/src/ssort/_single_dispatch.py b/src/ssort/_single_dispatch.py deleted file mode 100644 index 3d2b989..0000000 --- a/src/ssort/_single_dispatch.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -import functools -from typing import Any, Callable, Generic, TypeVar - -_T = TypeVar("_T") - - -__all__ = ["single_dispatch"] - - -class _SingleDispatch(Generic[_T]): - """A more performant implementation of functools.singledispatch.""" - - def __init__(self, function: Callable[..., _T]) -> None: - functools.update_wrapper(self, function) - self._function: Callable[..., _T] = function - self._functions: dict[type[Any], Callable[..., _T]] = {} - - def register( - self, cls: type[Any] - ) -> Callable[[Callable[..., _T]], Callable[..., _T]]: - def decorator(function: Callable[..., _T]) -> Callable[..., _T]: - self._functions[cls] = function - return function - - return decorator - - def __call__(self, arg: Any, *args: Any) -> _T: - return self._functions.get(type(arg), self._function)(arg, *args) - - -single_dispatch = _SingleDispatch diff --git a/tests/test_ast.py b/tests/test_ast.py deleted file mode 100644 index 77c445f..0000000 --- a/tests/test_ast.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import annotations - -import ast -import sys -from typing import Iterable - -import pytest - -from ssort._ast import iter_child_nodes - -_deprecated_node_types: tuple[type[ast.AST], ...] = ( - ast.AugLoad, - ast.AugStore, - ast.Param, - ast.Suite, -) - -if sys.version_info >= (3, 9): - _deprecated_node_types += (ast.Index, ast.ExtSlice) - -_ignored_node_types: tuple[type[ast.AST], ...] = ( - ast.expr_context, - ast.boolop, - ast.operator, - ast.unaryop, - ast.cmpop, -) - - -def _nodes_types( - node_type: type[ast.AST] = ast.AST, -) -> Iterable[type[ast.AST]]: - # Skip deprecated AST nodes. - if issubclass(node_type, _deprecated_node_types): - return - - # Skip ignored AST nodes. - if issubclass(node_type, _ignored_node_types): - return - - subclasses = node_type.__subclasses__() - if subclasses: - # Note that we do not yield the node_type if it has any subclasses. - # This is because AST base classes are used for categorical purposes - # only and are not intended to be instantiated. - for subclass in subclasses: - yield from _nodes_types(subclass) - else: - yield node_type - - -def _instantiate_node(node_type: type[ast.AST]) -> ast.AST: - # AST node fields are either strings or iterables of child AST nodes. 
The - # empty string satisfies both those requirements. - return node_type(*([""] * len(node_type._fields))) - - -def parametrize_nodes() -> pytest.MarkDecorator: - node_types = list(_nodes_types()) - nodes = [_instantiate_node(node_type) for node_type in node_types] - ids = [node_type.__name__ for node_type in node_types] - - return pytest.mark.parametrize("node", nodes, ids=ids) - - -def test_iter_child_nodes_is_not_implemented_for_none() -> None: - with pytest.raises(NotImplementedError): - iter_child_nodes(None) - - -@parametrize_nodes() -def test_iter_child_nodes_is_implemented(node: ast.AST) -> None: - list(iter_child_nodes(node)) From a1ababeea3fa75c1d3d453a0351943b749a63e4f Mon Sep 17 00:00:00 2001 From: Constantin Gahr Date: Mon, 1 May 2023 09:48:06 +0200 Subject: [PATCH 61/61] extract smart NodeVisitor logic to SmartNodeVisitor --- src/ssort/_bindings.py | 75 ++++++++++++++------------------------ src/ssort/_node_visitor.py | 27 ++++++++++++++ src/ssort/_requirements.py | 35 ++++++------------ 3 files changed, 66 insertions(+), 71 deletions(-) create mode 100644 src/ssort/_node_visitor.py diff --git a/src/ssort/_bindings.py b/src/ssort/_bindings.py index fb975f1..8e9822b 100644 --- a/src/ssort/_bindings.py +++ b/src/ssort/_bindings.py @@ -3,92 +3,71 @@ import ast from typing import Sequence -__all__ = ["get_bindings"] - - -class Bindings(ast.NodeVisitor): - def __init__(self): - self.stack = [] - - def append(self, name: list[str] | str | None): - if name is None: - return +from ._node_visitor import SmartNodeVisitor - if isinstance(name, list): - self.stack.extend(name) - - else: - self.stack.append(name) - - def flexible_visit(self, node: Sequence[ast.AST] | ast.AST | None): - if node is None: - return +__all__ = ["get_bindings"] - if isinstance(node, Sequence): - for n in node: - self.flexible_visit(n) - else: - self.visit(node) +class Bindings(SmartNodeVisitor): def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef): - self.flexible_visit(node.decorator_list) - self.append(node.name) + self.smart_visit(node.decorator_list) + self.smart_append(node.name) # body is missing - self.flexible_visit(node.args) - self.flexible_visit(node.returns) + self.smart_visit(node.args) + self.smart_visit(node.returns) visit_AsyncFunctionDef = visit_FunctionDef def visit_ClassDef(self, node: ast.ClassDef): - self.flexible_visit(node.decorator_list) - self.flexible_visit(node.bases) - self.flexible_visit(node.keywords) - self.append(node.name) + self.smart_visit(node.decorator_list) + self.smart_visit(node.bases) + self.smart_visit(node.keywords) + self.smart_append(node.name) # missing keywords, starargs, body def visit_Import(self, node): for name in node.names: if name.asname: - self.append(name.asname) + self.smart_append(name.asname) else: root, *rest = name.name.split(".", 1) - self.append(root) + self.smart_append(root) def visit_ImportFrom(self, node): for name in node.names: - self.append(name.asname if name.asname else name.name) + self.smart_append(name.asname if name.asname else name.name) def visit_Global(self, node: ast.Global | ast.Nonlocal): - self.append(node.names) + self.smart_append(node.names) visit_Nonlocal = visit_Global def visit_Lambda(self, node): - self.flexible_visit(node.args) + self.smart_visit(node.args) # def visit_alias(self, node: ast.alias): - # self.append(node.asname if node.asname else node.name) + # self.smart_append(node.asname if node.asname else node.name) def visit_Name(self, node: ast.Name): if isinstance(node.ctx, ast.Store): - 
self.append(node.id) + self.smart_append(node.id) def visit_ExceptHandler(self, node: ast.ExceptHandler): - self.flexible_visit(node.type) - self.append(node.name) - self.flexible_visit(node.body) + self.smart_visit(node.type) + self.smart_append(node.name) + self.smart_visit(node.body) def visit_MatchStar(self, node: ast.MatchStar): - self.append(node.name) + self.smart_append(node.name) def visit_MatchMapping(self, node: ast.MatchMapping): - self.flexible_visit(node.keys) - self.flexible_visit(node.patterns) - self.append(node.rest) + self.smart_visit(node.keys) + self.smart_visit(node.patterns) + self.smart_append(node.rest) def visit_MatchAs(self, node: ast.MatchAs): - self.flexible_visit(node.pattern) - self.append(node.name) + self.smart_visit(node.pattern) + self.smart_append(node.name) def get_bindings(node: ast.AST): diff --git a/src/ssort/_node_visitor.py b/src/ssort/_node_visitor.py new file mode 100644 index 0000000..9ef81d0 --- /dev/null +++ b/src/ssort/_node_visitor.py @@ -0,0 +1,27 @@ +import ast +from typing import Sequence + + +class SmartNodeVisitor(ast.NodeVisitor): + def __init__(self): + self.stack = [] + + def smart_append(self, name: list[str] | str | None): + if name is None: + return + + if isinstance(name, list): + self.stack.extend(name) + + else: + self.stack.append(name) + + def smart_visit(self, node: Sequence[ast.AST] | ast.AST | None): + if node is None: + return + + if isinstance(node, Sequence): + for n in node: + self.smart_visit(n) + else: + self.visit(node) diff --git a/src/ssort/_requirements.py b/src/ssort/_requirements.py index 15a85ea..078f759 100644 --- a/src/ssort/_requirements.py +++ b/src/ssort/_requirements.py @@ -8,6 +8,8 @@ from ssort._bindings import get_bindings from ssort._builtins import CLASS_BUILTINS +from ._node_visitor import SmartNodeVisitor + __all__ = ["get_requirements", "Requirement"] @@ -44,24 +46,11 @@ def get_requirements(node: ast.AST): yield from requirements.stack -class Requirements(ast.NodeVisitor): - def __init__(self): - self.stack = [] - - def flexible_visit(self, node: Sequence[ast.AST] | ast.AST | None): - if node is None: - return - - if isinstance(node, Sequence): - for n in node: - self.flexible_visit(n) - else: - self.visit(node) - +class Requirements(SmartNodeVisitor): def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef): - self.flexible_visit(node.decorator_list) - self.flexible_visit(node.args) - self.flexible_visit(node.returns) + self.smart_visit(node.decorator_list) + self.smart_visit(node.args) + self.smart_visit(node.returns) scope = get_scope_from_arguments(node.args) @@ -88,8 +77,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef): visit_AsyncFunctionDef = visit_FunctionDef def visit_ClassDef(self, node: ast.ClassDef): - self.flexible_visit(node.decorator_list) - self.flexible_visit(node.bases) + self.smart_visit(node.decorator_list) + self.smart_visit(node.bases) scope = set(CLASS_BUILTINS) @@ -103,8 +92,8 @@ def visit_ClassDef(self, node: ast.ClassDef): def visit_For(self, node: ast.For | ast.AsyncFor): bindings = set(get_bindings(node)) - self.flexible_visit(node.target) - self.flexible_visit(node.iter) + self.smart_visit(node.target) + self.smart_visit(node.iter) for stmt in node.body: for requirement in get_requirements(stmt): @@ -121,7 +110,7 @@ def visit_For(self, node: ast.For | ast.AsyncFor): def visit_With(self, node: ast.With | ast.AsyncWith): bindings = set(get_bindings(node)) - self.flexible_visit(node.items) + self.smart_visit(node.items) for 
stmt in node.body: for requirement in get_requirements(stmt): @@ -153,7 +142,7 @@ def visit_Nonlocal(self, node: ast.Nonlocal): ) def visit_Lambda(self, node: ast.Lambda): - self.flexible_visit(node.args) + self.smart_visit(node.args) scope = get_scope_from_arguments(node.args) scope.update(get_bindings(node.body))
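
The final patch above factors the shared `smart_visit`/`smart_append` helpers out of `Bindings` and `Requirements` into a common `SmartNodeVisitor` base class in `src/ssort/_node_visitor.py`. As a minimal sketch of how that base class is meant to be reused, the snippet below defines a hypothetical subclass, `AssignedNames`, that collects every name bound by an assignment. It assumes PATCH 61 has been applied so that `ssort._node_visitor` exposes `SmartNodeVisitor` exactly as shown in the diff; the subclass, the sample source string, and the printed result are illustrative only and are not part of the patch series.

# Illustrative sketch only -- not part of the patches above. Assumes PATCH 61
# is applied, so ssort._node_visitor provides SmartNodeVisitor as shown.
import ast

from ssort._node_visitor import SmartNodeVisitor


class AssignedNames(SmartNodeVisitor):
    """Hypothetical visitor: collect every name bound by an assignment."""

    def visit_Assign(self, node: ast.Assign) -> None:
        # smart_visit accepts a single node, a sequence of nodes, or None,
        # so the list of assignment targets can be passed through directly.
        self.smart_visit(node.targets)

    def visit_Name(self, node: ast.Name) -> None:
        if isinstance(node.ctx, ast.Store):
            # smart_append accepts a str, a list of str, or None.
            self.smart_append(node.id)


tree = ast.parse("x = 1\ny, z = 2, 3")
visitor = AssignedNames()
visitor.smart_visit(tree.body)
print(visitor.stack)  # ['x', 'y', 'z']

The point of the refactor is visible here: a visitor subclass never has to branch on "node, list of nodes, or None" itself, which is exactly the boilerplate that was previously duplicated as `flexible_visit`/`append` in both `_bindings.py` and `_requirements.py`.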