-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add stricter type checking and expected matrix sizes (#42)
--------- Co-authored-by: Corwin Joy <[email protected]> Co-authored-by: Danny Friar <[email protected]> Co-authored-by: Max Balandat <[email protected]> Co-authored-by: Geoff Pleiss <[email protected]>
- Loading branch information
1 parent
5191ee1
commit 32ba847
Showing
54 changed files
with
2,015 additions
and
936 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Pull Request CI | ||
|
||
on: | ||
pull_request: | ||
branches: [ main ] | ||
|
||
jobs: | ||
run_linter: | ||
uses: ./.github/workflows/run_linter.yml | ||
|
||
run_test_suite: | ||
uses: ./.github/workflows/run_test_suite.yml | ||
|
||
run_small_type_checked_test_suite: | ||
uses: ./.github/workflows/run_type_checked_test_suite.yml | ||
with: | ||
files_to_test: "test/operators/test_dense_linear_operator.py test/operators/test_diag_linear_operator.py test/operators/test_kronecker_product_linear_operator.py" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Push to Main CI | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
|
||
jobs: | ||
run_linter: | ||
uses: ./.github/workflows/run_linter.yml | ||
|
||
run_test_suite: | ||
uses: ./.github/workflows/run_type_checked_test_suite.yml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Run Linter | ||
|
||
on: | ||
workflow_call: | ||
|
||
jobs: | ||
run_linter: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
pip install flake8==5.0.4 flake8-print==5.0.0 pre-commit | ||
pre-commit install | ||
pre-commit run seed-isort-config || true | ||
- name: Run linting | ||
run: | | ||
flake8 | ||
- name: Run pre-commit checks | ||
# skipping flake8 here (run separatey above b/c pre-commit does not include flake8-print) | ||
run: | | ||
SKIP=flake8 pre-commit run --files test/**/*.py linear_operator/**/*.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Run Type Checked Test Suite | ||
|
||
on: | ||
workflow_call: | ||
inputs: | ||
files_to_test: | ||
required: false | ||
type: string | ||
|
||
jobs: | ||
run_type_checked_unit_tests: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
pytorch-version: ["latest", "stable"] | ||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
if [[ ${{ matrix.pytorch-version }} = "latest" ]]; then | ||
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; | ||
else | ||
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html; | ||
fi | ||
pip install -e ".[test]" | ||
- name: Run unit tests | ||
run: | | ||
pytest ${{ inputs.files_to_test }} --jaxtyping-packages=linear_operator,typeguard.typechecked |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/usr/bin/env bash | ||
echo "HI" 1>&2 | ||
CHANGED_FILES=$(git diff --cached --name-only | grep linear_operator/operators) | ||
echo $CHANGED_FILES 1>&2 | ||
if [[ -n "$CHANGED_FILES" ]]; then | ||
python ./.hooks/propagate_type_hints.py | ||
else | ||
echo "NO CHANGED FILES" 1>&2 | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Propagate type hints & signatures defined in _linear_operator.py to derived classes. | ||
# Here we leverage libcst which can preserve the original whitespace & formatting of the file | ||
# The idea is that we only want to change the type hints. | ||
# This way we can enforce consistency between the base class signature and derived signatures. | ||
|
||
import os | ||
from pathlib import Path | ||
from typing import List, Optional, Tuple, TypedDict | ||
|
||
import libcst as cst | ||
|
||
|
||
class Annotations(TypedDict): | ||
key: Tuple[str, ...] # key: tuple of canonical class/function name | ||
value: Tuple[cst.Parameters, Optional[cst.Annotation]] # value: (params, returns) | ||
|
||
|
||
class TypingCollector(cst.CSTVisitor): | ||
def __init__(self) -> None: | ||
# stack for storing the canonical name of the current function | ||
self.stack: List[Tuple[str, ...]] = [] | ||
# store the annotations | ||
self.annotations: Annotations = {} | ||
|
||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: | ||
self.stack.append(node.name.value) | ||
|
||
def leave_ClassDef(self, node: cst.ClassDef) -> None: | ||
self.stack.pop() | ||
|
||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: | ||
self.stack.append(node.name.value) | ||
self.annotations[tuple(self.stack)] = (node.params, node.returns) | ||
return False # pyi files don't support inner functions, return False to stop the traversal. | ||
|
||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None: | ||
self.stack.pop() | ||
|
||
|
||
class TypingTransformer(cst.CSTTransformer): | ||
|
||
# List of LinearOperator functions we do not want to propagate the signature from | ||
excluded_functions = ["__init__", "_check_args", "__torch_function__"] | ||
|
||
def __init__(self, annotations: Annotations): | ||
# stack for storing the canonical name of the current function | ||
self.stack: List[Tuple[str, ...]] = [] | ||
# store the annotations | ||
self.annotations: Annotations = annotations | ||
|
||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: | ||
self.stack.append(node.name.value) | ||
|
||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode: | ||
self.stack.pop() | ||
return updated_node | ||
|
||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: | ||
self.stack.append(node.name.value) | ||
return False # pyi files don't support inner functions, return False to stop the traversal. | ||
|
||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode: | ||
key = tuple(self.stack) | ||
if key[-1] in TypingTransformer.excluded_functions: | ||
return updated_node | ||
try: | ||
if original_node.params.params[0].name.value != "self": # Assume this is not a class method | ||
return updated_node | ||
except Exception: | ||
return updated_node | ||
key = ("LinearOperator", key[-1]) | ||
self.stack.pop() | ||
if key in self.annotations: | ||
annotations = self.annotations[key] | ||
return updated_node.with_changes(params=annotations[0], returns=annotations[1]) | ||
return updated_node | ||
|
||
|
||
def collect_base_type_hints(base_filename: Path) -> TypingCollector: | ||
base_tree = cst.parse_module(base_filename.read_text()) | ||
base_visitor = TypingCollector() | ||
base_tree.visit(base_visitor) | ||
return base_visitor | ||
|
||
|
||
def copy_base_type_hints_to_derived(target: Path, base_visitor: TypingCollector) -> cst.Module: | ||
source_tree = cst.parse_module(target.read_text()) | ||
transformer = TypingTransformer(base_visitor.annotations) | ||
modified_tree = source_tree.visit(transformer) | ||
return modified_tree | ||
|
||
|
||
def main(): | ||
directory = "linear_operator/operators" | ||
base_filename = Path(directory) / "_linear_operator.py" | ||
base_visitor = collect_base_type_hints(base_filename) | ||
|
||
os.environ["TYPE_HINTS_PROPAGATED"] = "0" | ||
changed_files = [] | ||
|
||
pathlist = Path(directory).glob("*.py") | ||
for path in pathlist: | ||
if path.name[0] == "_": | ||
continue | ||
target = path | ||
target_out = path | ||
original_code = target.read_text() | ||
modified_code = copy_base_type_hints_to_derived(target, base_visitor).code | ||
if original_code != modified_code: | ||
changed_files.append(path) | ||
with open(target_out, "w") as f: | ||
f.write(modified_code) | ||
|
||
if len(changed_files): | ||
print("The following files have been changed:") # noqa T201 | ||
for changed_file in changed_files: | ||
print(f" - {changed_file}") # noqa T201 | ||
os.environ["TYPE_HINTS_PROPAGATED"] = "1" | ||
|
||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/usr/bin/env bash | ||
if [[ -n "$(echo $@ | grep linear_operator/operators)" ]]; then | ||
python ./.hooks/propagate_type_hints.py | ||
if [[ $TYPE_HINTS_PROPAGATED = 1 ]]; then | ||
exit 2 | ||
fi | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
jaxtyping>=0.2.9 | ||
myst-parser | ||
setuptools_scm | ||
sphinx | ||
sphinx_rtd_theme | ||
sphinx-autodoc-typehints | ||
six | ||
uncompyle6 |
Oops, something went wrong.