Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch.

Note that, because ATen vec picks up Sleef internally and ignores it externally, this PR gets to enable optimized GELU in OSS.
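
For reference, exact GELU is GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))); the snippet below (not part of this diff) checks that form against F.gelu's default, non-approximate path:

import math

import torch
import torch.nn.functional as F

x = torch.randn(16)
# Exact (erf-based) GELU; matches F.gelu with approximate="none" (the default).
ref = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
assert torch.allclose(F.gelu(x), ref, atol=1e-6)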

Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break.

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)

[ghstack-poisoned]
swolchok committed Dec 2, 2024
2 parents ca0fa70 + 05e9a40 commit f5a9843
Showing 60 changed files with 1,798 additions and 479 deletions.
11 changes: 11 additions & 0 deletions backends/arm/TARGETS
@@ -110,3 +110,14 @@ python_library(
"//executorch/backends/arm/operators:node_visitor",
],
)

python_library(
    name = "arm_model_evaluator",
    srcs = [
        "util/arm_model_evaluator.py",
    ],
    typing = True,
    deps = [
        "//caffe2:torch",
    ],
)
6 changes: 3 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
@@ -29,8 +29,8 @@
DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
-from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
-    InsertSqueezeAfterSumPass,
+from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
+    KeepDimsFalseToSqueezePass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
@@ -71,7 +71,7 @@ def transform_to_backend_pipeline(
self.add_pass(DecomposeMeanDimPass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
-        self.add_pass(InsertSqueezeAfterSumPass())
+        self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
58 changes: 58 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -7,6 +7,7 @@

# pyre-unsafe

from inspect import isclass
from typing import Optional

import torch
@@ -133,3 +134,60 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
fake_tensor, FakeTensor
), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
return fake_tensor


def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
"""
    Helper function for getting a value from node.args/kwargs. Three cases:
    1. By position in node.args - returns the arg at the given position, or default_value if the index is exactly one past the end; anything further out of bounds raises
    2. By key in node.kwargs - returns the kwarg with the given key, or default_value if it does not exist
    3. By type in node.args - returns the first arg of the given type. Useful for cases where arg positions may differ but types are unique.
"""
if isinstance(key, int):
if 0 <= key < len(args):
return args[key]
elif key == len(args):
if default_value is not None:
return default_value
else:
raise RuntimeError(f"No defult value given for index {key}")
else:
raise RuntimeError(
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
)
elif isinstance(key, str):
return args.get(key, default_value)
elif isclass(key):
for arg in args:
if isinstance(arg, key):
return arg
if default_value is not None:
return default_value
else:
raise RuntimeError(f"No arg of type {key}")
else:
raise RuntimeError("Invalid type")


def set_node_arg(node: torch.fx.Node, i: int | str, value):
"""
    Helper function for setting a value in node.args/kwargs. If the index equals the current number of args, the value is appended instead.
"""
if isinstance(i, int):
if 0 <= i < len(node.args):
args = list(node.args)
args[i] = value
node.args = tuple(args)
return
elif i == len(node.args):
node.args = node.args + (value,)
else:
raise RuntimeError(
f"Out of bounds index {i} for setting value in {node} args (of size {len(node.args)})"
)
elif isinstance(i, str):
kwargs = dict(node.kwargs)
kwargs[i] = value
node.kwargs = kwargs
else:
raise RuntimeError("Invalid type")
13 changes: 7 additions & 6 deletions backends/arm/_passes/decompose_meandim_pass.py
@@ -7,6 +7,7 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

@@ -42,16 +43,16 @@ def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
return super().call_operator(op, args, kwargs, meta)

-        x = args[0]
-        dim = args[1]
-        keepdim = args[2] if len(args) > 2 else False
-        if not keepdim:
-            return super().call_operator(op, args, kwargs, meta)
-        # if keepdim == True and dim == [-1, -2], mean.dim can be
+        x = get_node_arg(args, 0)
+        dim = get_node_arg(args, 1)
+        keepdim = get_node_arg(args, 2, False)
+
+        # if dim == [-1, -2], mean.dim can be
# decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
if dim == [-1, -2]:
# Simply return the mean.dim operator for future decomposition.
return super().call_operator(op, args, kwargs, meta)

shape = meta["val"].size()
dtype = meta["val"].dtype
input_shape = x.data.size()
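
The avg_pool2d special case relies on a simple equivalence; a quick numerical check (not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# Mean over the last two dims with keepdim=True equals an average pool whose
# kernel covers the full spatial extent.
by_mean = x.mean(dim=[-1, -2], keepdim=True)
by_pool = F.avg_pool2d(x, kernel_size=x.shape[-2:])
assert torch.allclose(by_mean, by_pool, atol=1e-6)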
27 changes: 16 additions & 11 deletions backends/arm/_passes/decompose_var_pass.py
@@ -8,6 +8,7 @@


import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

@@ -53,26 +54,30 @@ def call_operator(self, op, args, kwargs, meta):
torch.ops.aten.var.dim,
):
return super().call_operator(op, args, kwargs, meta)
-        shape = meta["val"].size()
+
+        x = args[0]
+        input_shape = x.data.size()
+        shape = list(meta["val"].size())
+        if shape == []:
+            shape = [1 for _ in input_shape]

dtype = meta["val"].dtype
-        dim = args[1] if len(args) > 1 else list(range(len(shape)))
+        # Get dim from args based on argument type
+        dim = get_node_arg(args, key=list, default_value=list(range(len(shape))))

if op == torch.ops.aten.var.dim:
-            correction = args[-2]
-            keepdim = args[-1]
+            keepdim = get_node_arg(args, bool, False)
+            correction = get_node_arg(args, int, 1)
else:
-            correction = kwargs["correction"]
-            keepdim = kwargs.get("keepdim", False)
-        if not keepdim:
-            return super().call_operator(op, args, kwargs, meta)
+            correction = get_node_arg(kwargs, "correction", 1)
+            keepdim = get_node_arg(kwargs, "keepdim", False)

-        x = args[0]
-        input_shape = x.data.size()
N = 1
for d in dim:
N *= input_shape[d]

mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
-        mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
+        mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
diff = super().call_operator(diff_op, (x, mean), {}, meta)
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
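
A numerical sketch of the decomposition this pass builds (not part of the diff; it assumes the truncated tail of the hunk scales the summed squared differences by 1 / (N - correction) via full_op):

import torch

x = torch.randn(4, 8)
dim, correction = [1], 1
N = x.shape[1]  # number of elements reduced over

mean = x.mean(dim=dim, keepdim=True)  # mean_op, now always keepdim=True
sq_diff = (x - mean) * (x - mean)  # diff_op followed by mul_op
total = sq_diff.sum(dim=dim, keepdim=False)  # sum_op
var = total / (N - correction)  # scale factor assumed to come from full_op
assert torch.allclose(var, x.var(dim=dim, correction=correction))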
backends/arm/_passes/{insert_squeeze_after_sum_pass.py → keep_dims_false_to_squeeze_pass.py}
@@ -10,14 +10,18 @@

import torch
import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_node_arg,
+    set_node_arg,
+)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


-class InsertSqueezeAfterSumPass(ExportPass):
+class KeepDimsFalseToSqueezePass(ExportPass):
    """
-    In Pytorch, the default behaviour of Tensor.sum is to squeeze
+    In PyTorch, the default behaviour of, for example, Tensor.sum is to squeeze
the dimension that is summed (keep_dim = False).
However, in TOSA, REDUCE_SUM always preserves the
rank of the input (keep_dim = True).
@@ -31,28 +35,52 @@ class InsertSqueezeAfterSumPass(ExportPass):
squeeze(dim = dims)
"""

# CURRENTLY NOT HANDLED OPS
# exir_ops.edge.aten.amax,
# exir_ops.edge.aten.amin,
# exir_ops.edge.aten.any.dim,
# exir_ops.edge.aten.any.dims,
# exir_ops.edge.aten.argmax,
# exir_ops.edge.aten.argmin,
# exir_ops.edge.aten.max.dim,
# exir_ops.edge.aten.min.dim,
# exir_ops.edge.aten.prod.dim_int,

# HANDLED OPS
# exir_ops.edge.aten.sum.dim_IntList
# exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
# exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
# exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
keep_dim_index = None

if node.op != "call_function":
continue
-            if node.target != exir_ops.edge.aten.sum.dim_IntList:
+            if node.target == exir_ops.edge.aten.sum.dim_IntList:
+                keep_dim_index = 2
+            else:
                continue

sum_node = cast(torch.fx.Node, node)
-            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
+            keep_dim = get_node_arg(sum_node.args, keep_dim_index, False)

if keep_dim:
continue

-            dim_list = cast(list[int], sum_node.args[1])
+            dim_list = get_node_arg(sum_node.args, 1, [0])

# Add keep_dim = True arg to sum node.
-            sum_node.args = sum_node.args[0:2] + (True,)
+            set_node_arg(sum_node, 2, True)

with graph_module.graph.inserting_after(sum_node):
squeeze_node = create_node(
graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
)
sum_node.replace_all_uses_with(squeeze_node)
squeeze_node.args = (sum_node, dim_list)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
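
A minimal sketch of the rewrite this pass performs, with plain tensor ops standing in for the graph nodes:

import torch

x = torch.randn(2, 3, 4)
dims = (1,)

wanted = x.sum(dim=dims, keepdim=False)  # original graph semantics
rewritten = x.sum(dim=dims, keepdim=True).squeeze(dims)  # keep rank for TOSA, then squeeze
assert torch.equal(wanted, rewritten)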
7 changes: 1 addition & 6 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,4 @@

# pyre-unsafe

-from . import (  # noqa
-    mean_dim_support,
-    right_shift_support,
-    tosa_supported_operators,
-    var_correction_support,
-)
+from . import right_shift_support, to_copy_support, tosa_supported_operators  # noqa
33 changes: 0 additions & 33 deletions backends/arm/operator_support/mean_dim_support.py

This file was deleted.
