Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mypy check_untyped_defs #55

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ exclude = "tests/fixtures/*"

[tool.mypy]
exclude = ["tests/fixtures", "build"]
check_untyped_defs = true

[tool.setuptools.dynamic]
version = {attr = "torchfix.torchfix.__version__"}
5 changes: 3 additions & 2 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def _codemod_results(source_path):
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
new_module = codemod.transform_module(context, code)
if isinstance(new_module, codemod.TransformFailure):
if isinstance(new_module, codemod.TransformSuccess):
return new_module.code
elif isinstance(new_module, codemod.TransformFailure):
raise new_module.error
return new_module.code


def test_empty():
Expand Down
4 changes: 3 additions & 1 deletion torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import libcst as cst
import libcst.codemod as codemod

from .common import deep_multi_replace
from .common import deep_multi_replace, TorchVisitor
from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor
from .visitors.internal import TorchScopedLibraryVisitor

Expand Down Expand Up @@ -44,6 +44,7 @@
def GET_ALL_ERROR_CODES():
codes = set()
for cls in ALL_VISITOR_CLS:
assert issubclass(cls, TorchVisitor)
codes |= {error.error_code for error in cls.ERRORS}
return codes

Expand Down Expand Up @@ -79,6 +80,7 @@ def get_visitors_with_error_codes(error_codes):
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
assert issubclass(visitor_cls, TorchVisitor)
if any(error_code == err.error_code for err in visitor_cls.ERRORS):
visitor_classes.add(visitor_cls)
found = True
Expand Down
1 change: 1 addition & 0 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def read_deprecated_config(path=None):
deprecated_config = {}
if path is not None:
data = pkgutil.get_data("torchfix", path)
assert data is not None
for item in yaml.load(data, yaml.SafeLoader):
deprecated_config[item["name"]] = item
return deprecated_config
Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/vision/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor):

def visit_Call(self, node):
def _new_arg_and_import(
old_arg: cst.Arg, is_backbone: bool
old_arg: Optional[cst.Arg], is_backbone: bool
) -> Optional[cst.Arg]:
old_arg_name = "pretrained_backbone" if is_backbone else "pretrained"
if old_arg is None or (model_name, old_arg_name) not in self.MODEL_WEIGHTS:
Expand Down
2 changes: 1 addition & 1 deletion torchfix/visitors/vision/to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def visit_Attribute(self, node):
if len(qualified_names) != 1:
return

self._maybe_add_violation(qualified_names.pop().name, node)
self._maybe_add_violation(list(qualified_names)[0].name, node)