diff --git a/backends/transforms/fuse_conv_with_clamp.py b/backends/transforms/fuse_conv_with_clamp.py index 3f18402ee3..15973cae57 100644 --- a/backends/transforms/fuse_conv_with_clamp.py +++ b/backends/transforms/fuse_conv_with_clamp.py @@ -7,7 +7,9 @@ import sys import torch -from executorch.backends.vulkan.passes.custom_ops_defs import conv_with_clamp_op # noqa +from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa + conv_with_clamp_op, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index 3f966bd2ff..bbddb70735 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -28,6 +28,7 @@ runtime.python_library( "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:mean_to_sum_div", "//executorch/backends/transforms:remove_clone_ops", + "//executorch/backends/vulkan/_passes:remove_local_scalar_dense", "//executorch/exir:graph_module", "//executorch/exir/_serialize:_bindings", "//executorch/exir/_serialize:lib", diff --git a/backends/vulkan/passes/TARGETS b/backends/vulkan/_passes/TARGETS similarity index 100% rename from backends/vulkan/passes/TARGETS rename to backends/vulkan/_passes/TARGETS diff --git a/backends/vulkan/passes/custom_ops_defs.py b/backends/vulkan/_passes/custom_ops_defs.py similarity index 100% rename from backends/vulkan/passes/custom_ops_defs.py rename to backends/vulkan/_passes/custom_ops_defs.py diff --git a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py new file mode 100644 index 0000000000..0f71076498 --- /dev/null +++ b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph: + """ + Remove local_scalar_dense op nodes and replace uses with parent node, or the + original scalar tensor. + """ + target_op = torch.ops.aten._local_scalar_dense.default + for node in graph.nodes: + if node.op == "call_function" and node.target == target_op: + replace_node = node.args[0] + # If the argument to the local_scalar_dense op is a select op with only + # one user, and the argument to the select op is a tensor with only one + # element (i.e. a scalar tensor), then replace the entire pattern with the + # scalar tensor. + if ( + replace_node.op == "call_function" + and replace_node.target == exir_ops.edge.aten.select_copy.int + ): + if replace_node.args[0].meta["val"].numel() == 1: + replace_node = replace_node.args[0] + + with graph.inserting_after(node): + node.replace_all_uses_with(replace_node) + + graph.eliminate_dead_code() + return graph + + +class RemoveLocalScalarDenseOpsTransform(ExportPass): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph) + return PassResult(graph_module, True) diff --git a/backends/vulkan/passes/test_custom_ops.py b/backends/vulkan/_passes/test_custom_ops.py similarity index 100% rename from backends/vulkan/passes/test_custom_ops.py rename to backends/vulkan/_passes/test_custom_ops.py diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index ca7ce72cae..903b179fcc 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -8,7 +8,7 @@ import operator -from executorch.backends.vulkan.passes.custom_ops_defs import ( # noqa +from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa conv_with_clamp_op, grid_priors_op, )