Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for TorchVisionModelsImportVisitor #29

Merged
merged 2 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/fixtures/vision/checker/models_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from torchvision.models import resnet50, resnet101
import torchvision.models
from torchvision.models import *
import torchvision.models as models, torch
1 change: 1 addition & 0 deletions tests/fixtures/vision/checker/models_import.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
6:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
5 changes: 5 additions & 0 deletions tests/fixtures/vision/codemod/models_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import torchvision.models as models
import torchvision.models as cnn

# don't touch if more than one name imported
import torchvision.models as models, torch
5 changes: 5 additions & 0 deletions tests/fixtures/vision/codemod/models_import.py.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from torchvision import models
import torchvision.models as cnn

# don't touch if more than one name imported
import torchvision.models as models, torch
22 changes: 14 additions & 8 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

class TorchVisionModelsImportVisitor(TorchVisitor):
ERROR_CODE = "TOR203"
MESSAGE = (
"Consider replacing 'import torchvision.models as models' "
"with 'from torchvision import models'."
)

def visit_Import(self, node: cst.Import) -> None:
replacement = None
for imported_item in node.names:
if isinstance(imported_item.name, cst.Attribute):
# TODO refactor using libcst.matchers.matches
if (
isinstance(imported_item.name.value, cst.Name)
and imported_item.name.value.value == "torchvision"
Expand All @@ -21,20 +27,20 @@ def visit_Import(self, node: cst.Import) -> None:
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
# Replace only if the import statement has no other names
if len(node.names) == 1:
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=(
"Consider replacing 'import torchvision.models as"
" models' with 'from torchvision import models'."
),
message=self.MESSAGE,
line=position.start.line,
column=position.start.column,
node=node,
replacement=replacement
)
)
break
Loading