diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d30e95a5c..4c1098741 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,11 @@ repos: rev: 6.1.0 hooks: - id: flake8 +- repo: local + hooks: + - id: check-substrait-extensions + name: Check Substrait extensions + entry: pytest tests/test_extensions.py::test_read_substrait_extensions + language: python + pass_filenames: false + diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 466667266..3cb7e5294 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -1,11 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 +import os + from tests.coverage.extensions import build_type_to_short_type +# NOTE: this test is run as part of pre-commit hook def test_read_substrait_extensions(): from tests.coverage.extensions import Extension - registry = Extension.read_substrait_extensions("../extensions") + current_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_path = os.path.join(current_dir, "../extensions") + registry = Extension.read_substrait_extensions(extensions_path) assert len(registry.registry) >= 161 num_overloads = sum([len(f) for f in registry.registry.values()]) assert num_overloads >= 510