Skip to content

Commit

Permalink
[refactoring] Extract helper method has_specific_arg (#49)
Browse files Browse the repository at this point in the history
fix #8, extract helper method `has_specific_arg` that checks for the call argument presence, and simplify all relevant call sites
  • Loading branch information
izaitsevfb authored Apr 22, 2024
1 parent b2d55f8 commit af37f69
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 44 deletions.
17 changes: 16 additions & 1 deletion torchfix/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def __init__(self) -> None:
def get_specific_arg(
node: cst.Call, arg_name: str, arg_pos: int
) -> Optional[cst.Arg]:
# `arg_pos` is zero-based.
"""
:param arg_pos: `arg_pos` is zero-based. -1 means it's a keyword argument.
:note: consider using `has_specific_arg` if you only need to check for presence.
"""
curr_pos = 0
for arg in node.args:
if arg.keyword is None:
Expand All @@ -73,6 +76,18 @@ def get_specific_arg(
return arg
return None

@staticmethod
def has_specific_arg(
node: cst.Call, arg_name: str, position: Optional[int] = None
) -> bool:
"""
Check if the specific argument is present in a call.
"""
return TorchVisitor.get_specific_arg(
node, arg_name,
position if position is not None else -1
) is not None

def add_violation(
self,
node: cst.CSTNode,
Expand Down
33 changes: 16 additions & 17 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,19 @@ class TorchReentrantCheckpointVisitor(TorchVisitor):
]

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name == "torch.utils.checkpoint.checkpoint":
use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1)
if use_reentrant_arg is None:
# This codemod maybe unsafe correctness-wise
# if reentrant behavior is actually needed,
# so the changes need to be verified/tested.
use_reentrant_arg = cst.ensure_type(
cst.parse_expression("f(use_reentrant=False)"), cst.Call
).args[0]
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
)
if (self.get_qualified_name_for_call(node) ==
"torch.utils.checkpoint.checkpoint" and
not self.has_specific_arg(node, "use_reentrant")):
# This codemod maybe unsafe correctness-wise
# if reentrant behavior is actually needed,
# so the changes need to be verified/tested.
use_reentrant_arg = cst.ensure_type(
cst.parse_expression("f(use_reentrant=False)"), cst.Call
).args[0]
replacement = node.with_changes(args=node.args + (use_reentrant_arg,))
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
)
49 changes: 23 additions & 26 deletions torchfix/visitors/security/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,27 @@ class TorchUnsafeLoadVisitor(TorchVisitor):
]

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name == "torch.load":
weights_only_arg = self.get_specific_arg(node, "weights_only", -1)
if weights_only_arg is None:
# Add `weights_only=True` if there is no `pickle_module`.
# (do not add `weights_only=False` with `pickle_module`, as it
# needs to be an explicit choice).
#
# This codemod is somewhat unsafe correctness-wise
# because full pickling functionality may still be needed
# even without `pickle_module`,
# so the changes need to be verified/tested.
replacement = None
pickle_module_arg = self.get_specific_arg(node, "pickle_module", 2)
if pickle_module_arg is None:
weights_only_arg = cst.ensure_type(
cst.parse_expression("f(weights_only=True)"), cst.Call
).args[0]
replacement = node.with_changes(
args=node.args + (weights_only_arg,)
)
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
if self.get_qualified_name_for_call(node) == "torch.load" and \
not self.has_specific_arg(node, "weights_only"):
# Add `weights_only=True` if there is no `pickle_module`.
# (do not add `weights_only=False` with `pickle_module`, as it
# needs to be an explicit choice).
#
# This codemod is somewhat unsafe correctness-wise
# because full pickling functionality may still be needed
# even without `pickle_module`,
# so the changes need to be verified/tested.
replacement = None
if not self.has_specific_arg(node, "pickle_module", 2):
weights_only_arg = cst.ensure_type(
cst.parse_expression("f(weights_only=True)"), cst.Call
).args[0]
replacement = node.with_changes(
args=node.args + (weights_only_arg,)
)
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=replacement,
)

0 comments on commit af37f69

Please sign in to comment.