Skip to content

Commit

Permalink
Handle new dims for repeat in pass
Browse files Browse the repository at this point in the history
In addition to moving logic from node visitor,
this also fixes repeating a rank 3 tensor to
make a rank 3 tensor.

Signed-off-by: Erik Lundell <[email protected]>
Change-Id: I7090159bce47b6aa4d6613bbeb2d681d5cfcb193
  • Loading branch information
Erik-Lundell authored and freddan80 committed Dec 3, 2024
1 parent 44bcfc3 commit 4f47cc9
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 31 deletions.
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
ScalarsToAttributePass,
)
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
UnsqueezeBeforeRepeatPass,
)
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
UnsqueezeScalarPlaceholdersPass,
)
Expand All @@ -66,6 +69,7 @@ def transform_to_backend_pipeline(
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(DecomposeVarPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(DecomposeMeanDimPass())
Expand Down
62 changes: 62 additions & 0 deletions backends/arm/_passes/unsqueeze_before_repeat_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class UnsqueezeBeforeRepeatPass(ExportPass):
"""
A TOSA TILE op only supports rank(in) == rank(out).
To support Pytorch's repeat which can also add dimensions,
we add an explicit view op before which adds the new dimensions.
New dimensions are appendend at the front, see
https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
Original:
repeat(multiples)
After pass:
view(shape = [1]*num_new_dims + old_shape)
repeat(multiples)
"""

def call(self, graph_module: torch.fx.GraphModule):
modified_graph = False
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target != exir_ops.edge.aten.repeat.default:
continue

old_shape = list(get_first_fake_tensor(node.all_input_nodes[0]).shape)
old_rank = len(old_shape)
multiples = node.args[1]
new_rank = len(multiples)
if old_rank == new_rank:
continue

num_new_dims = new_rank - old_rank
new_shape = [1] * num_new_dims + old_shape

with graph_module.graph.inserting_before(node):
view_node = create_node(
graph_module.graph,
exir_ops.edge.aten.view_copy.default,
(node.all_input_nodes[0], new_shape),
)
node.replace_input_with(node.all_input_nodes[0], view_node)
modified_graph = True

if modified_graph:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, modified_graph)
31 changes: 1 addition & 30 deletions backends/arm/operators/op_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,37 +32,8 @@ def define_node(
is_quant_node: bool,
) -> None:

item_name = inputs[0].name
shape = inputs[0].shape
rank = len(shape)
multiples = inputs[1].special
new_rank = len(multiples)

assert new_rank >= rank

# TILE only supports rank(in) == rank(out). To add more dims, we need a reshape first.
if new_rank > rank:
# Add length 1 dimensions to shape to match multiples
num_new_dims = new_rank - rank
expanded_shape = tuple(
1 if i < num_new_dims else shape[i - num_new_dims]
for i in range(new_rank)
)
expanded_shape = tosa_shape(expanded_shape, output.dim_order)
dtype = (
ts.dtype_str_to_val("INT8")
if is_quant_node
else ts.dtype_str_to_val("FP32")
)

rescale_out = tosa_graph.addIntermediate(expanded_shape, dtype)
rescale_attr = ts.TosaSerializerAttribute()
rescale_attr.ReshapeAttribute(expanded_shape)
tosa_graph.addOperator(
TosaOp.Op().RESHAPE, [item_name], [rescale_out.name], rescale_attr
)
item_name = rescale_out.name

attr = ts.TosaSerializerAttribute()
attr.TileAttribute(tosa_shape(multiples, output.dim_order))
tosa_graph.addOperator(TosaOp.Op().TILE, [item_name], [output.name], attr)
tosa_graph.addOperator(TosaOp.Op().TILE, [inputs[0].name], [output.name], attr)
11 changes: 10 additions & 1 deletion backends/arm/test/ops/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Repeat(torch.nn.Module):
(torch.randn(3), (2, 2)),
(torch.randn(3), (1, 2, 3)),
(torch.randn((3, 3)), (2, 2, 2)),
(torch.randn((3, 3, 3)), (2, 1, 2, 4)),
]

def forward(self, x: torch.Tensor, multiples: Sequence):
Expand Down Expand Up @@ -106,12 +107,20 @@ def test_repeat_tosa_MI(self, test_input, multiples):
def test_repeat_tosa_BI(self, test_input, multiples):
self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))

@parameterized.expand(Repeat.test_parameters)
@parameterized.expand(Repeat.test_parameters[:-1])
def test_repeat_u55_BI(self, test_input, multiples):
self._test_repeat_ethosu_pipeline(
common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
)

# Final test requires transpose which is not supported on u55.
@parameterized.expand(Repeat.test_parameters[-1:])
@unittest.expectedFailure
def test_repeat_u55_BI_xfails(self, test_input, multiples):
self._test_repeat_ethosu_pipeline(
common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
)

@parameterized.expand(Repeat.test_parameters)
def test_repeat_u85_BI(self, test_input, multiples):
self._test_repeat_ethosu_pipeline(
Expand Down
74 changes: 74 additions & 0 deletions backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
UnsqueezeBeforeRepeatPass,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import RunPasses


class Repeat(torch.nn.Module):
"""
Basic repeat model.
"""

def forward(self, x: torch.Tensor):
return x.repeat(2, 2, 2, 2)


class TestUnsqueezeBeforeRepeatPass(unittest.TestCase):
def test_tosa_MI_insert_view(self):
"""
When rank(input) != number of repeated dimensions (=4 in Repeat module),
insert view.
"""
module = Repeat()
inputs = (torch.rand((2, 3, 4)),)
test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass])
(
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
)
.export()
.to_edge()
.check(["aten_repeat_default"])
.check_not(["aten_view_copy_default"])
.run_passes(test_pass_stage)
.check(["aten_repeat_default", "aten_view_copy_default"])
)
)

def test_tosa_MI_dont_insert_view(self):
"""
When rank(input) == number of repeated dimensions (=4 in Repeat module),
DON'T insert view.
"""
module = Repeat()
inputs = (torch.rand((2, 3, 4, 1)),)
test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass])
(
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
)
.export()
.to_edge()
.check(["aten_repeat_default"])
.check_not(["aten_view_copy_default"])
.run_passes(test_pass_stage)
.check(["aten_repeat_default"])
.check_not(["aten_view_copy_default"])
)
)

0 comments on commit 4f47cc9

Please sign in to comment.