Skip to content

Commit

Permalink
Edit operator metatypes to include torch.nn.hardswish_ and torch.drop…
Browse files Browse the repository at this point in the history
…out_
  • Loading branch information
anzr299 committed Jul 8, 2024
1 parent 125a10f commit d9e4c41
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class PTHardTanhMetatype(PTOperatorMetatype):
@PT_OPERATOR_METATYPES.register()
class PTHardSwishMetatype(PTOperatorMetatype):
name = "HardSwishOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardswish"]}
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardswish", "hardswish_"]}
num_expected_input_edges = 1


Expand Down Expand Up @@ -693,7 +693,7 @@ class PTRoundMetatype(PTOperatorMetatype):
@PT_OPERATOR_METATYPES.register()
class PTDropoutMetatype(PTOperatorMetatype):
name = "DropoutOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["dropout"]}
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["dropout"], NamespaceTarget.TORCH: ["dropout_"]}


@PT_OPERATOR_METATYPES.register()
Expand Down

0 comments on commit d9e4c41

Please sign in to comment.