Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torchfix: Refactor ERROR_CODE to be consistent #46

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def test_errorcodes_distinct():
seen = set()
for visitor in visitors:
LOGGER.info("Checking error code for %s", visitor.__class__.__name__)
error_code = visitor.ERROR_CODE
for e in error_code if isinstance(error_code, list) else [error_code]:
errors = visitor.ERRORS
for e in errors:
assert e not in seen
seen.add(e)

Expand Down
10 changes: 3 additions & 7 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,12 @@ def main() -> None:
parser.add_argument(
"--select",
help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. "
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
type=str,
default=None,
)
parser.add_argument(
"--version",
action="version",
version=f"{TorchFixVersion}"
)
parser.add_argument("--version", action="version", version=f"{TorchFixVersion}")

# XXX TODO: Get rid of this!
# Silence "Failed to determine module name"
Expand Down
22 changes: 17 additions & 5 deletions torchfix/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import dataclass
import sys
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple

import libcst as cst
from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider
from libcst.codemod.visitors import ImportItem
from typing import Optional, List, Set, Tuple, Union
from abc import ABC
from libcst.metadata import QualifiedNameProvider, WhitespaceInclusivePositionProvider

IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
CYAN = "\033[96m" if IS_TTY else ""
Expand Down Expand Up @@ -34,13 +35,24 @@ def codemod_result(self) -> str:
return f"{position} {error_code}{fixable} {self.message}"


@dataclass(frozen=True)
class TorchError:
    """Pair an error code with the template used to build its message.

    The dataclass is frozen, so instances cannot be modified after creation.
    """

    # Identifier of the rule, reported alongside the message.
    error_code: str
    # ``str.format``-style template; named fields are filled in by ``message``.
    message_template: str

    def message(self, **kwargs) -> str:
        """Return ``message_template`` formatted with the given keyword fields.

        A template without replacement fields is returned unchanged when
        called with no arguments.
        """
        template = self.message_template
        rendered = template.format(**kwargs)
        return rendered


class TorchVisitor(cst.BatchableCSTVisitor, ABC):
METADATA_DEPENDENCIES = (
QualifiedNameProvider,
WhitespaceInclusivePositionProvider,
)

ERROR_CODE: Union[str, List[str]]
ERRORS: List[TorchError]

def __init__(self) -> None:
self.violations: List[LintViolation] = []
Expand Down
27 changes: 10 additions & 17 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .visitors.internal import TorchScopedLibraryVisitor

from .visitors.performance import TorchSynchronizedDataLoaderVisitor
from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor)
from .visitors.misc import TorchRequireGradVisitor, TorchReentrantCheckpointVisitor
from .visitors.nonpublic import TorchNonPublicAliasVisitor

from .visitors.vision import (
Expand Down Expand Up @@ -48,10 +48,7 @@
def GET_ALL_ERROR_CODES():
    """Return the set of all error codes declared by the registered visitors.

    Every class in ``ALL_VISITOR_CLS`` exposes an ``ERRORS`` list; the
    ``error_code`` of each entry is collected. The legacy ``ERROR_CODE``
    attribute (a string or a list of strings) is no longer consulted.
    """
    return {error.error_code for cls in ALL_VISITOR_CLS for error in cls.ERRORS}


Expand Down Expand Up @@ -86,16 +83,10 @@ def get_visitors_with_error_codes(error_codes):
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
if isinstance(visitor_cls.ERROR_CODE, list):
if error_code in visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
else:
if error_code == visitor_cls.ERROR_CODE:
visitor_classes.add(visitor_cls)
found = True
break
if error_code in list(err.error_code for err in visitor_cls.ERRORS):
visitor_classes.add(visitor_cls)
found = True
break
if not found:
raise AssertionError(f"Unknown error code: {error_code}")
out = []
Expand All @@ -120,8 +111,10 @@ def process_error_code_str(code_str):
if c == "ALL":
continue
if len(expand_error_codes((c,))) == 0:
raise ValueError(f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}")
raise ValueError(
f"Invalid error code: {c}, available error "
f"codes: {list(GET_ALL_ERROR_CODES())}"
)

if "ALL" in raw_codes:
return GET_ALL_ERROR_CODES()
Expand Down
25 changes: 16 additions & 9 deletions torchfix/visitors/deprecated_symbols/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import libcst as cst
import pkgutil
import yaml
from typing import Optional
from typing import Optional, List
from collections.abc import Sequence

from ...common import (
TorchVisitor,
TorchError,
call_with_name_changes,
)

Expand All @@ -16,7 +17,12 @@


class TorchDeprecatedSymbolsVisitor(TorchVisitor):
ERROR_CODE = ["TOR001", "TOR101", "TOR004", "TOR103"]
# The visitor methods index this list by position: entries 0 and 1 are used
# for calls (removed / deprecated), entries 2 and 3 for imports
# (removed / deprecated). Keep the order stable.
ERRORS: List[TorchError] = [
    TorchError("TOR001", "Use of removed function {qualified_name}"),
    # TOR101 is reported for calls, so its text must say "Use of", not
    # "Import of" (which is TOR103's message).
    TorchError("TOR101", "Use of deprecated function {qualified_name}"),
    TorchError("TOR004", "Import of removed function {qualified_name}"),
    TorchError("TOR103", "Import of deprecated function {qualified_name}"),
]

def __init__(self, deprecated_config_path=None):
def read_deprecated_config(path=None):
Expand Down Expand Up @@ -67,11 +73,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
qualified_name = f"{module}.{name.name.value}"
if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[3]
message = f"Import of deprecated function {qualified_name}"
error_code = self.ERRORS[3].error_code
message = self.ERRORS[3].message(qualified_name=qualified_name)
else:
error_code = self.ERROR_CODE[2]
message = f"Import of removed function {qualified_name}"
error_code = self.ERRORS[2].error_code
message = self.ERRORS[2].message(qualified_name=qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
if reference is not None:
Expand All @@ -86,11 +92,12 @@ def visit_Call(self, node) -> None:

if qualified_name in self.deprecated_config:
if self.deprecated_config[qualified_name]["remove_pr"] is None:
error_code = self.ERROR_CODE[1]
error_code = self.ERRORS[1].error_code
message = self.ERRORS[1].message(qualified_name=qualified_name)
message = f"Use of deprecated function {qualified_name}"
else:
error_code = self.ERROR_CODE[0]
message = f"Use of removed function {qualified_name}"
error_code = self.ERRORS[0].error_code
message = self.ERRORS[0].message(qualified_name=qualified_name)
replacement = self._call_replacement(node, qualified_name)

reference = self.deprecated_config[qualified_name].get("reference")
Expand Down
26 changes: 18 additions & 8 deletions torchfix/visitors/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchScopedLibraryVisitor(TorchVisitor):
    """
    Suggest `torch.library._scoped_library` for PyTorch tests.
    """

    # Single rule handled by this visitor; the template has no replacement
    # fields, so ``message()`` is called without arguments.
    ERRORS = [
        TorchError(
            "TOR901",
            (
                "Use `torch.library._scoped_library` "
                "instead of `torch.library.Library` "
                "in PyTorch tests files. "
                "See https://github.com/pytorch/pytorch/pull/118318 "
                "for details."
            ),
        )
    ]

    def visit_Call(self, node):
        """Record one TOR901 violation for a call to `torch.library.Library`."""
        qualified_name = self.get_qualified_name_for_call(node)
        if qualified_name != "torch.library.Library":
            return
        error = self.ERRORS[0]
        # Report exactly once, through the ERRORS entry; the old
        # ERROR_CODE/MESSAGE attributes are gone.
        self.add_violation(
            node,
            error_code=error.error_code,
            message=error.message(),
        )
35 changes: 21 additions & 14 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import libcst as cst
import libcst.matchers as m


from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchRequireGradVisitor(TorchVisitor):
"""
Find and fix common misspelling `require_grad` (instead of `requires_grad`).
"""

ERROR_CODE = "TOR002"
MESSAGE = "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?"
ERRORS = [
TorchError(
"TOR002",
"Likely typo `require_grad` in assignment. Did you mean `requires_grad`?",
)
]

def visit_Assign(self, node):
# Look for any assignment with `require_grad` attribute on the left.
Expand All @@ -33,8 +36,8 @@ def visit_Assign(self, node):
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
)

Expand All @@ -44,12 +47,16 @@ class TorchReentrantCheckpointVisitor(TorchVisitor):
Find and fix common misuse of reentrant checkpoints.
"""

ERROR_CODE = "TOR003"
MESSAGE = (
"Please pass `use_reentrant` explicitly to `checkpoint`. "
"To maintain old behavior, pass `use_reentrant=True`. "
"It is recommended to use `use_reentrant=False`."
)
ERRORS = [
TorchError(
"TOR003",
(
"Please pass `use_reentrant` explicitly to `checkpoint`. "
"To maintain old behavior, pass `use_reentrant=True`. "
"It is recommended to use `use_reentrant=False`."
),
)
]

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
Expand All @@ -65,7 +72,7 @@ def visit_Call(self, node):
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
)
31 changes: 24 additions & 7 deletions torchfix/visitors/nonpublic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from os.path import commonprefix
from typing import Sequence
from typing import Sequence, List

import libcst as cst
from libcst.codemod.visitors import ImportItem

from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchNonPublicAliasVisitor(TorchVisitor):
Expand All @@ -17,7 +17,20 @@ class TorchNonPublicAliasVisitor(TorchVisitor):
see https://github.com/pytorch/pytorch/pull/69862/files
"""

ERROR_CODE = ["TOR104", "TOR105"]
ERRORS: List[TorchError] = [
TorchError(
"TOR104", (
"Use of non-public function `{qualified_name}`, "
"please use `{public_name}` instead"
),
),
TorchError(
"TOR105", (
"Import of non-public function `{qualified_name}`, "
"please use `{public_name}` instead"
),
),
]

# fmt: off
ALIASES = {
Expand All @@ -33,8 +46,10 @@ def visit_Call(self, node):

if qualified_name in self.ALIASES:
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[0]
message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
error_code = self.ERRORS[0].error_code
message = self.ERRORS[0].message(
qualified_name=qualified_name, public_name=public_name
)

call_name = cst.helpers.get_full_name_for_node(node)
replacement = None
Expand Down Expand Up @@ -74,8 +89,10 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
qualified_name = f"{module}.{name.name.value}"
if qualified_name in self.ALIASES:
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[1]
message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
error_code = self.ERRORS[1].error_code
message = self.ERRORS[1].message(
qualified_name=qualified_name, public_name=public_name
)

new_module = ".".join(public_name.split(".")[:-1])
new_name = public_name.split(".")[-1]
Expand Down
23 changes: 14 additions & 9 deletions torchfix/visitors/performance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import libcst.matchers as m


from ...common import TorchVisitor
from ...common import TorchError, TorchVisitor


class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
Expand All @@ -10,12 +9,16 @@ class TorchSynchronizedDataLoaderVisitor(TorchVisitor):
https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
"""

ERROR_CODE = "TOR401"
MESSAGE = (
"Detected DataLoader running with synchronized implementation. "
"Please enable asynchronous dataloading by setting num_workers > 0 when "
"initializing DataLoader."
)
ERRORS = [
TorchError(
"TOR401",
(
"Detected DataLoader running with synchronized implementation."
" Please enable asynchronous dataloading by setting "
"num_workers > 0 when initializing DataLoader."
),
)
]

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
Expand All @@ -25,5 +28,7 @@ def visit_Call(self, node):
num_workers_arg.value, m.Integer(value="0")
):
self.add_violation(
node, error_code=self.ERROR_CODE, message=self.MESSAGE
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
)
Loading