Skip to content

Commit

Permalink
Add stricter type checking and expected matrix sizes (#42)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Corwin Joy <[email protected]>
Co-authored-by: Danny Friar <[email protected]>
Co-authored-by: Max Balandat <[email protected]>
Co-authored-by: Geoff Pleiss <[email protected]>
  • Loading branch information
5 people authored May 2, 2023
1 parent 5191ee1 commit 32ba847
Show file tree
Hide file tree
Showing 54 changed files with 2,015 additions and 936 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ on:
types: [created]

jobs:
run_linter:
uses: ./.github/workflows/run_linter.yml

run_test_suite:
uses: ./.github/workflows/run_test_suite.yml
uses: ./.github/workflows/run_type_checked_test_suite.yml

deploy_pypi:
runs-on: ubuntu-latest
Expand Down
20 changes: 20 additions & 0 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Pull Request CI

on:
pull_request:
branches: [ main ]

jobs:
run_linter:
uses: ./.github/workflows/run_linter.yml

run_test_suite:
uses: ./.github/workflows/run_test_suite.yml

run_small_type_checked_test_suite:
uses: ./.github/workflows/run_type_checked_test_suite.yml
with:
files_to_test: "test/operators/test_dense_linear_operator.py test/operators/test_diag_linear_operator.py test/operators/test_kronecker_product_linear_operator.py"
15 changes: 15 additions & 0 deletions .github/workflows/push_to_main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Push to Main CI

on:
push:
branches: [ main ]

jobs:
run_linter:
uses: ./.github/workflows/run_linter.yml

run_test_suite:
uses: ./.github/workflows/run_type_checked_test_suite.yml
29 changes: 29 additions & 0 deletions .github/workflows/run_linter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Run Linter

on:
workflow_call:

jobs:
run_linter:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
- name: Install dependencies
run: |
pip install flake8==5.0.4 flake8-print==5.0.0 pre-commit
pre-commit install
pre-commit run seed-isort-config || true
- name: Run linting
run: |
flake8
- name: Run pre-commit checks
# skipping flake8 here (run separatey above b/c pre-commit does not include flake8-print)
run: |
SKIP=flake8 pre-commit run --files test/**/*.py linear_operator/**/*.py
27 changes: 1 addition & 26 deletions .github/workflows/run_test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,9 @@
name: Run Test Suite

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_call:

jobs:
run_linter:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
- name: Install dependencies
run: |
pip install flake8==5.0.4 flake8-print==5.0.0 pre-commit
pre-commit install
pre-commit run seed-isort-config || true
- name: Run linting
run: |
flake8
- name: Run pre-commit checks
# skipping flake8 here (run separatey above b/c pre-commit does not include flake8-print)
run: |
SKIP=flake8 pre-commit run --files test/**/*.py linear_operator/**/*.py
run_unit_tests:
runs-on: ubuntu-latest
strategy:
Expand All @@ -50,7 +25,7 @@ jobs:
else
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html;
fi
pip install -e .
pip install -e ".[test]"
- name: Run unit tests
run: |
python -m unittest discover
35 changes: 35 additions & 0 deletions .github/workflows/run_type_checked_test_suite.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Run Type Checked Test Suite

on:
workflow_call:
inputs:
files_to_test:
required: false
type: string

jobs:
run_type_checked_unit_tests:
runs-on: ubuntu-latest
strategy:
matrix:
pytorch-version: ["latest", "stable"]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
- name: Install dependencies
run: |
if [[ ${{ matrix.pytorch-version }} = "latest" ]]; then
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html;
else
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html;
fi
pip install -e ".[test]"
- name: Run unit tests
run: |
pytest ${{ inputs.files_to_test }} --jaxtyping-packages=linear_operator,typeguard.typechecked
9 changes: 9 additions & 0 deletions .hooks/check_type_hints.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
echo "HI" 1>&2
CHANGED_FILES=$(git diff --cached --name-only | grep linear_operator/operators)
echo $CHANGED_FILES 1>&2
if [[ -n "$CHANGED_FILES" ]]; then
python ./.hooks/propagate_type_hints.py
else
echo "NO CHANGED FILES" 1>&2
fi
121 changes: 121 additions & 0 deletions .hooks/propagate_type_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Propagate type hints & signatures defined in _linear_operator.py to derived classes.
# Here we leverage libcst which can preserve the original whitespace & formatting of the file
# The idea is that we only want to change the type hints.
# This way we can enforce consistency between the base class signature and derived signatures.

import os
from pathlib import Path
from typing import List, Optional, Tuple, TypedDict

import libcst as cst


class Annotations(TypedDict):
key: Tuple[str, ...] # key: tuple of canonical class/function name
value: Tuple[cst.Parameters, Optional[cst.Annotation]] # value: (params, returns)


class TypingCollector(cst.CSTVisitor):
def __init__(self) -> None:
# stack for storing the canonical name of the current function
self.stack: List[Tuple[str, ...]] = []
# store the annotations
self.annotations: Annotations = {}

def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.stack.append(node.name.value)

def leave_ClassDef(self, node: cst.ClassDef) -> None:
self.stack.pop()

def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.stack.append(node.name.value)
self.annotations[tuple(self.stack)] = (node.params, node.returns)
return False # pyi files don't support inner functions, return False to stop the traversal.

def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
self.stack.pop()


class TypingTransformer(cst.CSTTransformer):

# List of LinearOperator functions we do not want to propagate the signature from
excluded_functions = ["__init__", "_check_args", "__torch_function__"]

def __init__(self, annotations: Annotations):
# stack for storing the canonical name of the current function
self.stack: List[Tuple[str, ...]] = []
# store the annotations
self.annotations: Annotations = annotations

def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.stack.append(node.name.value)

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
self.stack.pop()
return updated_node

def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.stack.append(node.name.value)
return False # pyi files don't support inner functions, return False to stop the traversal.

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
key = tuple(self.stack)
if key[-1] in TypingTransformer.excluded_functions:
return updated_node
try:
if original_node.params.params[0].name.value != "self": # Assume this is not a class method
return updated_node
except Exception:
return updated_node
key = ("LinearOperator", key[-1])
self.stack.pop()
if key in self.annotations:
annotations = self.annotations[key]
return updated_node.with_changes(params=annotations[0], returns=annotations[1])
return updated_node


def collect_base_type_hints(base_filename: Path) -> TypingCollector:
base_tree = cst.parse_module(base_filename.read_text())
base_visitor = TypingCollector()
base_tree.visit(base_visitor)
return base_visitor


def copy_base_type_hints_to_derived(target: Path, base_visitor: TypingCollector) -> cst.Module:
source_tree = cst.parse_module(target.read_text())
transformer = TypingTransformer(base_visitor.annotations)
modified_tree = source_tree.visit(transformer)
return modified_tree


def main():
directory = "linear_operator/operators"
base_filename = Path(directory) / "_linear_operator.py"
base_visitor = collect_base_type_hints(base_filename)

os.environ["TYPE_HINTS_PROPAGATED"] = "0"
changed_files = []

pathlist = Path(directory).glob("*.py")
for path in pathlist:
if path.name[0] == "_":
continue
target = path
target_out = path
original_code = target.read_text()
modified_code = copy_base_type_hints_to_derived(target, base_visitor).code
if original_code != modified_code:
changed_files.append(path)
with open(target_out, "w") as f:
f.write(modified_code)

if len(changed_files):
print("The following files have been changed:") # noqa T201
for changed_file in changed_files:
print(f" - {changed_file}") # noqa T201
os.environ["TYPE_HINTS_PROPAGATED"] = "1"


main()
7 changes: 7 additions & 0 deletions .hooks/propagate_type_hints.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/usr/bin/env bash
if [[ -n "$(echo $@ | grep linear_operator/operators)" ]]; then
python ./.hooks/propagate_type_hints.py
if [[ $TYPE_HINTS_PROPAGATED = 1 ]]; then
exit 2
fi
fi
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ repos:
hooks:
- id: forbid-crlf
- id: forbid-tabs
- repo: local
hooks:
- id: propagate-type-hints
name: Propagate Type Hints
entry: ./.hooks/propagate_type_hints.sh
language: script
pass_filenames: true
require_serial: true
34 changes: 30 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,38 @@ We use [standard sphinx docstrings](https://sphinx-rtd-tutorial.readthedocs.io/e
LinearOperator aims to be fully typed using Python 3.8+
[type hints](https://www.python.org/dev/peps/pep-0484/).
We expect any contributions to also use proper type annotations.
While we currently do not enforce full consistency of these in our continuous integration
test, you should strive to type check your code locally.
For this we recommend using [pyre](https://pyre-check.org/).

We are using [jaxtyping](https://github.com/google/jaxtyping) to help us be declarative about the dimension sizes used
in the LinearOperator methods.
The use of [jaxtyping](https://github.com/google/jaxtyping) makes it clearer what the functions are doing algebraically
and where broadcasting is happening.

These type hints are checked in the unit tests by using
[typeguard](https://github.com/agronholm/typeguard) to perform run-time
checking of the signatures to make sure they are accurate.
The signatures are written into the base linear operator class in `_linear_oparator.py`.
These signatures are then copied to the derived classes by running the script
`python ./.hooks/propagate_type_hints.py`.
This is done for:
1. Consistency. Make sure the derived implementations are following the promised interface.
2. Visibility. Make it easy to see what the expected signature is, along with dimensions. Repeating the signature in the derived classes enhances readability.
3. Necessity. The way that jaxtyping and typeguard are written, they won't run type checks unless type annotations are present in the derived method signature.

In short, if you want to update the type hints, update the code in the LinearOperator class in
`_linear_oparator.py` then run `python ./.hooks/propagate_type_hints.py` to copy the new signature to the derived
classes.

### Unit Tests

#### With type checking (slower, but more thorough)
To run the unittests with type checking, run
```bash
pytest --jaxtyping-packages=linear_operator,typeguard.typechecked
```

- To run tests within a specific directory, run (e.g.) `pytest test/operators --jaxtyping-packages=linear_operator,typeguard.typechecked`.
- To run a specific file, run (e.g.) `pytest test/operators/test_matmul_linear_operator.py --jaxtyping-packages=linear_operator,typeguard.typechecked`.

#### Without type checking (faster, but less thorough)
We use python's `unittest` to run unit tests:
```bash
python -m unittest
Expand All @@ -56,6 +81,7 @@ python -m unittest
- To run a specific unit test, run (e.g.) `python -m unittest test.operators.test_matmul_linear_operator.TestMatmulLinearOperator.test_matmul_vec`.



### Documentation

LinearOperator uses sphinx to generate documentation, and ReadTheDocs to host documentation.
Expand Down
2 changes: 2 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
jaxtyping>=0.2.9
myst-parser
setuptools_scm
sphinx
sphinx_rtd_theme
sphinx-autodoc-typehints
six
uncompyle6
Loading

0 comments on commit 32ba847

Please sign in to comment.