diff --git a/pyproject.toml b/pyproject.toml index 87bbb70..8c27031 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ exclude = "tests/fixtures/*" [tool.mypy] exclude = ["tests/fixtures", "build"] +check_untyped_defs = true [tool.setuptools.dynamic] version = {attr = "torchfix.torchfix.__version__"} diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 5f5dff9..d8be743 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -27,9 +27,10 @@ def _codemod_results(source_path): config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES())) context = TorchCodemod(codemod.CodemodContext(filename=source_path), config) new_module = codemod.transform_module(context, code) - if isinstance(new_module, codemod.TransformFailure): + if isinstance(new_module, codemod.TransformSuccess): + return new_module.code + elif isinstance(new_module, codemod.TransformFailure): raise new_module.error - return new_module.code def test_empty(): diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 32e3b9d..4ffaaa4 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -5,7 +5,7 @@ import libcst as cst import libcst.codemod as codemod -from .common import deep_multi_replace +from .common import deep_multi_replace, TorchVisitor from .visitors.deprecated_symbols import TorchDeprecatedSymbolsVisitor from .visitors.internal import TorchScopedLibraryVisitor @@ -44,6 +44,7 @@ def GET_ALL_ERROR_CODES(): codes = set() for cls in ALL_VISITOR_CLS: + assert issubclass(cls, TorchVisitor) codes |= {error.error_code for error in cls.ERRORS} return codes @@ -79,6 +80,7 @@ def get_visitors_with_error_codes(error_codes): # only correspond to one visitor. found = False for visitor_cls in ALL_VISITOR_CLS: + assert issubclass(visitor_cls, TorchVisitor) if any(error_code == err.error_code for err in visitor_cls.ERRORS): visitor_classes.add(visitor_cls) found = True diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index f2b9cf2..cf088e8 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -29,6 +29,7 @@ def read_deprecated_config(path=None): deprecated_config = {} if path is not None: data = pkgutil.get_data("torchfix", path) + assert data is not None for item in yaml.load(data, yaml.SafeLoader): deprecated_config[item["name"]] = item return deprecated_config diff --git a/torchfix/visitors/vision/pretrained.py b/torchfix/visitors/vision/pretrained.py index af52a0f..acbe564 100644 --- a/torchfix/visitors/vision/pretrained.py +++ b/torchfix/visitors/vision/pretrained.py @@ -177,7 +177,7 @@ class TorchVisionDeprecatedPretrainedVisitor(TorchVisitor): def visit_Call(self, node): def _new_arg_and_import( - old_arg: cst.Arg, is_backbone: bool + old_arg: Optional[cst.Arg], is_backbone: bool ) -> Optional[cst.Arg]: old_arg_name = "pretrained_backbone" if is_backbone else "pretrained" if old_arg is None or (model_name, old_arg_name) not in self.MODEL_WEIGHTS: diff --git a/torchfix/visitors/vision/to_tensor.py b/torchfix/visitors/vision/to_tensor.py index 791a9e5..02a5915 100644 --- a/torchfix/visitors/vision/to_tensor.py +++ b/torchfix/visitors/vision/to_tensor.py @@ -39,4 +39,4 @@ def visit_Attribute(self, node): if len(qualified_names) != 1: return - self._maybe_add_violation(qualified_names.pop().name, node) + self._maybe_add_violation(list(qualified_names)[0].name, node)