Skip to content

Commit

Permalink
Move vulkan.passes to vulkan._passes (#5919) (#6129)
Browse files Browse the repository at this point in the history
Summary:
Changing vulkan.passes to vulkan._passes to indicate that these passes are not covered under the API stability guarantee.

Pull Request resolved: #5919

Reviewed By: helunwencser

Differential Revision: D63926849

fbshipit-source-id: bf135c46c6718bc37afa640cf51d004891516575
(cherry picked from commit e1832ef)
  • Loading branch information
tarun292 authored Oct 11, 2024
1 parent 9aa96f8 commit 905a26f
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 2 deletions.
4 changes: 3 additions & 1 deletion backends/transforms/fuse_conv_with_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import sys

import torch
from executorch.backends.vulkan.passes.custom_ops_defs import conv_with_clamp_op # noqa
from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
conv_with_clamp_op,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ runtime.python_library(
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/vulkan/_passes:remove_local_scalar_dense",
"//executorch/exir:graph_module",
"//executorch/exir/_serialize:_bindings",
"//executorch/exir/_serialize:lib",
Expand Down
File renamed without changes.
File renamed without changes.
44 changes: 44 additions & 0 deletions backends/vulkan/_passes/remove_local_scalar_dense_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
"""
Remove local_scalar_dense op nodes and replace uses with parent node, or the
original scalar tensor.
"""
target_op = torch.ops.aten._local_scalar_dense.default
for node in graph.nodes:
if node.op == "call_function" and node.target == target_op:
replace_node = node.args[0]
# If the argument to the local_scalar_dense op is a select op with only
# one user, and the argument to the select op is a tensor with only one
# element (i.e. a scalar tensor), then replace the entire pattern with the
# scalar tensor.
if (
replace_node.op == "call_function"
and replace_node.target == exir_ops.edge.aten.select_copy.int
):
if replace_node.args[0].meta["val"].numel() == 1:
replace_node = replace_node.args[0]

with graph.inserting_after(node):
node.replace_all_uses_with(replace_node)

graph.eliminate_dead_code()
return graph


class RemoveLocalScalarDenseOpsTransform(ExportPass):
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph)
return PassResult(graph_module, True)
File renamed without changes.
2 changes: 1 addition & 1 deletion backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import operator

from executorch.backends.vulkan.passes.custom_ops_defs import ( # noqa
from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
conv_with_clamp_op,
grid_priors_op,
)
Expand Down

0 comments on commit 905a26f

Please sign in to comment.