Skip to content

Commit

Permalink
torchfix: Refactor ERROR_CODE to be consistent (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
seemethere authored Apr 22, 2024
1 parent 86579eb commit b2d55f8
Show file tree
Hide file tree
Showing 13 changed files with 187 additions and 114 deletions.
4 changes: 2 additions & 2 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def test_errorcodes_distinct():
seen = set()
for visitor in visitors:
LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
error_code = visitor.ERROR_CODE
for e in error_code if isinstance(error_code, list) else [error_code]:
errors = visitor.ERRORS
for e in errors:
assert e not in seen
seen.add(e)

Expand Down
10 changes: 3 additions & 7 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,12 @@ def main() -> None:
parser.add_argument(
"--select",
help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. "
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
type=str,
default=None,
)
parser.add_argument(
"--version",
action="version",
version=f"{TorchFixVersion}"
)
parser.add_argument("--version", action="version", version=f"{TorchFixVersion}")

# XXX TODO: Get rid of this!
# Silence "Failed to determine module name"
Expand Down
22 changes: 17 additions & 5 deletions torchfix/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import dataclass
import sys
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple

import libcst as cst
from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider
from libcst.codemod.visitors import ImportItem
from typing import Optional, List, Set, Tuple, Union
from abc import ABC
from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider

IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
CYAN = "\033[96m" if IS_TTY else ""
Expand Down Expand Up @@ -34,13 +35,24 @@ def codemod_result(self) -> str:
return f"{position} {error_code}{fixable} {self.message}"


@dataclass(frozen=True)
class TorchError:
    """Pair a rule's error code with the template of its message.

    Instances are immutable (``frozen=True``), so a single ``TorchError``
    can safely be shared as a class-level entry of a visitor's ``ERRORS``
    list.

    Attributes:
        error_code: Rule identifier, e.g. ``"TOR001"``.
        message_template: ``str.format``-style template; named placeholders
            such as ``{qualified_name}`` are filled in by :meth:`message`.
    """

    error_code: str
    message_template: str

    def message(self, **kwargs: object) -> str:
        """Return ``message_template`` with its placeholders filled in.

        Args:
            **kwargs: Values for the named placeholders of the template.
                Extra keyword arguments are ignored by ``str.format``.

        Returns:
            The rendered, human-readable message.

        Raises:
            KeyError: If the template names a placeholder that was not
                supplied.
        """
        return self.message_template.format(**kwargs)


class TorchVisitor(cst.BatchableCSTVisitor, ABC):
METADATA_DEPENDENCIES = (
QualifiedNameProvider,
WhitespaceInclusivePositionProvider,
)

ERROR_CODE: Union[str, List[str]]
ERRORS: List[TorchError]

def __init__(self) -> None:
self.violations: List[LintViolation] = []
Expand Down
27 changes: 10 additions & 17 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .visitors.internal import TorchScopedLibraryVisitor

from .visitors.performance import TorchSynchronizedDataLoaderVisitor
from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor)
from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor
from .visitors.nonpublic import TorchNonPublicAliasVisitor

from .visitors.vision import (
Expand Down Expand Up @@ -48,10 +48,7 @@
def GET_ALL_ERROR_CODES():
codes = set()
for cls in ALL_VISITOR_CLS:
if isinstance(cls.ERROR_CODE, list):
codes |= set(cls.ERROR_CODE)
else:
codes.add(cls.ERROR_CODE)
codes |= set(error.error_code for error in cls.ERRORS)
return codes


Expand Down Expand Up @@ -86,16 +83,10 @@ def get_visitors_with_error_codes(error_codes):
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
if isinstance(visitor_cls.ERROR_CODE, list):
if error_code in visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
else:
if error_code == visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
if error_code in list(err.error_code for err in visitor_cls.ERRORS):
visitor_classes.add(visitor_cls)
found = True
break
if not found:
raise AssertionError(f"Unknown error code: {error_code}")
out = []
Expand All @@ -120,8 +111,10 @@ def process_error_code_str(code_str):
if c == "ALL":
continue
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")
raise ValueError(
f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}"
)

if "ALL" in raw_codes:
return GET_ALL_ERROR_CODES()
Expand Down
25 changes: 16 additions & 9 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import libcst as cst
import pkgutil
import yaml
from typing import Optional
from typing import Optional, List
from collections.abc import Sequence

from ...common import (
TorchVisitor,
TorchError,
call_with_name_changes,
)

Expand All @@ -16,7 +17,12 @@


class TorchDeprecatedSymbolsVisitor(TorchVisitor):
ERROR_CODE = ["TOR001", "TOR101", "TOR004", "TOR103"]
# Order matters: the visitor methods index into this list positionally
# (0 = removed call, 1 = deprecated call, 2 = removed import,
# 3 = deprecated import).
ERRORS: List[TorchError] = [
    TorchError("TOR001", "Use of removed function {qualified_name}"),
    # TOR101 is reported from visit_Call, i.e. for a *use* of a deprecated
    # function; the template previously said "Import of ...", duplicating
    # the TOR103 text.
    TorchError("TOR101", "Use of deprecated function {qualified_name}"),
    TorchError("TOR004", "Import of removed function {qualified_name}"),
    TorchError("TOR103", "Import of deprecated function {qualified_name}"),
]

def __init__(self, deprecated_config_path=None):
def read_deprecated_config(path=None):
Expand Down Expand Up @@ -67,11 +73,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
qualified_name = f"{module}.{name.name.value}"
if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[3]
message = f"Import of deprecated function {qualified_name}"
error_code = self.ERRORS[3].error_code
message = self.ERRORS[3].message(qualified_name=qualified_name)
else:
error_code = self.ERROR_CODE[2]
message = f"Import of removed function {qualified_name}"
error_code = self.ERRORS[2].error_code
message = self.ERRORS[2].message(qualified_name=qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
Expand All @@ -86,11 +92,12 @@ def visit_Call(self, node) -> None:

if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[1]
error_code = self.ERRORS[1].error_code
message = self.ERRORS[1].message(qualified_name=qualified_name)
message = f"Use of deprecated function {qualified_name}"
else:
error_code = self.ERROR_CODE[0]
message = f"Use of removed function {qualified_name}"
error_code = self.ERRORS[0].error_code
message = self.ERRORS[0].message(qualified_name=qualified_name)
replacement = self._call_replacement(node, qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
Expand Down
26 changes: 18 additions & 8 deletions torchfix/visitors/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchScopedLibraryVisitor(TorchVisitor):
"""
Suggest `torch.library._scoped_library` for PyTorch tests.
"""

ERROR_CODE = "TOR901"
MESSAGE = (
"Use `torch.library._scoped_library` instead of `torch.library.Library` "
"in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 "
"for details."
)
ERRORS = [
TorchError(
"TOR901",
(
"Use `torch.library._scoped_library` "
"instead of `torch.library.Library` "
"in PyTorch tests files. "
"See https://github.com/pytorch/pytorch/pull/118318 "
"for details."
),
)
]

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name == "torch.library.Library":
self.add_violation(node, error_code=self.ERROR_CODE, message=self.MESSAGE)
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
)
35 changes: 21 additions & 14 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import libcst as cst
import libcst.matchers as m


from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchRequireGradVisitor(TorchVisitor):
"""
Find and fix common misspelling `require_grad` (instead of `requires_grad`).
"""

ERROR_CODE = "TOR002"
MESSAGE = "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?"
ERRORS = [
TorchError(
"TOR002",
"Likely typo `require_grad` in assignment. Did you mean `requires_grad`?",
)
]

def visit_Assign(self, node):
# Look for any assignment with `require_grad` attribute on the left.
Expand All @@ -33,8 +36,8 @@ def visit_Assign(self, node):
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
)

Expand All @@ -44,12 +47,16 @@ class TorchReentrantCheckpointVisitor(TorchVisitor):
Find and fix common misuse of reentrant checkpoints.
"""

ERROR_CODE = "TOR003"
MESSAGE = (
"Please pass `use_reentrant` explicitly to `checkpoint`. "
"To maintain old behavior, pass `use_reentrant=True`. "
"It is recommended to use `use_reentrant=False`."
)
ERRORS = [
TorchError(
"TOR003",
(
"Please pass `use_reentrant` explicitly to `checkpoint`. "
"To maintain old behavior, pass `use_reentrant=True`. "
"It is recommended to use `use_reentrant=False`."
),
)
]

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
Expand All @@ -65,7 +72,7 @@ def visit_Call(self, node):
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
)
31 changes: 24 additions & 7 deletions torchfix/visitors/nonpublic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from os.path import commonprefix
from typing import Sequence
from typing import Sequence, List

import libcst as cst
from libcst.codemod.visitors import ImportItem

from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchNonPublicAliasVisitor(TorchVisitor):
Expand All @@ -17,7 +17,20 @@ class TorchNonPublicAliasVisitor(TorchVisitor):
see https://github.com/pytorch/pytorch/pull/69862/files
"""

ERROR_CODE = ["TOR104", "TOR105"]
ERRORS: List[TorchError] = [
TorchError(
"TOR104", (
"Use of non-public function `{qualified_name}`, "
"please use `{public_name}` instead"
),
),
TorchError(
"TOR105", (
"Import of non-public function `{qualified_name}`, "
"please use `{public_name}` instead"
),
),
]

# fmt: off
ALIASES = {
Expand All @@ -33,8 +46,10 @@ def visit_Call(self, node):

if qualified_name in self.ALIASES:
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[0]
message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
error_code = self.ERRORS[0].error_code
message = self.ERRORS[0].message(
qualified_name=qualified_name, public_name=public_name
)

call_name = cst.helpers.get_full_name_for_node(node)
replacement = None
Expand Down Expand Up @@ -74,8 +89,10 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
qualified_name = f"{module}.{name.name.value}"
if qualified_name in self.ALIASES:
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[1]
message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
error_code = self.ERRORS[1].error_code
message = self.ERRORS[1].message(
qualified_name=qualified_name, public_name=public_name
)

new_module = ".".join(public_name.split(".")[:-1])
new_name = public_name.split(".")[-1]
Expand Down
23 changes: 14 additions & 9 deletions torchfix/visitors/performance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import libcst.matchers as m


from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
Expand All @@ -10,12 +9,16 @@ class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
"""

ERROR_CODE = "TOR401"
MESSAGE = (
"Detected DataLoader running with synchronized implementation. "
"Please enable asynchronous dataloading by setting num_workers > 0 when "
"initializing DataLoader."
)
ERRORS = [
TorchError(
"TOR401",
(
"Detected DataLoader running with synchronized implementation."
" Please enable asynchronous dataloading by setting "
"num_workers > 0 when initializing DataLoader."
),
)
]

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
Expand All @@ -25,5 +28,7 @@ def visit_Call(self, node):
num_workers_arg.value, m.Integer(value="0")
):
self.add_violation(
node, error_code=self.ERROR_CODE, message=self.MESSAGE
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
)
Loading

0 comments on commit b2d55f8

Please sign in to comment.