-
Notifications
You must be signed in to change notification settings - Fork 379
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
In addition to moving logic from node visitor, this also fixes repeating a rank 3 tensor to make a rank 3 tensor. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I7090159bce47b6aa4d6613bbeb2d681d5cfcb193
- Loading branch information
1 parent
44bcfc3
commit 4f47cc9
Showing
5 changed files
with
151 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# pyre-unsafe | ||
import torch | ||
import torch.fx | ||
from executorch.backends.arm._passes.arm_pass_utils import ( | ||
create_node, | ||
get_first_fake_tensor, | ||
) | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
|
||
class UnsqueezeBeforeRepeatPass(ExportPass): | ||
""" | ||
A TOSA TILE op only supports rank(in) == rank(out). | ||
To support Pytorch's repeat which can also add dimensions, | ||
we add an explicit view op before which adds the new dimensions. | ||
New dimensions are appendend at the front, see | ||
https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html | ||
Original: | ||
repeat(multiples) | ||
After pass: | ||
view(shape = [1]*num_new_dims + old_shape) | ||
repeat(multiples) | ||
""" | ||
|
||
def call(self, graph_module: torch.fx.GraphModule): | ||
modified_graph = False | ||
for node in graph_module.graph.nodes: | ||
if node.op != "call_function": | ||
continue | ||
if node.target != exir_ops.edge.aten.repeat.default: | ||
continue | ||
|
||
old_shape = list(get_first_fake_tensor(node.all_input_nodes[0]).shape) | ||
old_rank = len(old_shape) | ||
multiples = node.args[1] | ||
new_rank = len(multiples) | ||
if old_rank == new_rank: | ||
continue | ||
|
||
num_new_dims = new_rank - old_rank | ||
new_shape = [1] * num_new_dims + old_shape | ||
|
||
with graph_module.graph.inserting_before(node): | ||
view_node = create_node( | ||
graph_module.graph, | ||
exir_ops.edge.aten.view_copy.default, | ||
(node.all_input_nodes[0], new_shape), | ||
) | ||
node.replace_input_with(node.all_input_nodes[0], view_node) | ||
modified_graph = True | ||
|
||
if modified_graph: | ||
graph_module.recompile() | ||
graph_module = super().call(graph_module).graph_module | ||
return PassResult(graph_module, modified_graph) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import unittest | ||
|
||
import torch | ||
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( | ||
UnsqueezeBeforeRepeatPass, | ||
) | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.backends.xnnpack.test.tester.tester import RunPasses | ||
|
||
|
||
class Repeat(torch.nn.Module): | ||
""" | ||
Basic repeat model. | ||
""" | ||
|
||
def forward(self, x: torch.Tensor): | ||
return x.repeat(2, 2, 2, 2) | ||
|
||
|
||
class TestUnsqueezeBeforeRepeatPass(unittest.TestCase): | ||
def test_tosa_MI_insert_view(self): | ||
""" | ||
When rank(input) != number of repeated dimensions (=4 in Repeat module), | ||
insert view. | ||
""" | ||
module = Repeat() | ||
inputs = (torch.rand((2, 3, 4)),) | ||
test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass]) | ||
( | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=inputs, | ||
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), | ||
) | ||
.export() | ||
.to_edge() | ||
.check(["aten_repeat_default"]) | ||
.check_not(["aten_view_copy_default"]) | ||
.run_passes(test_pass_stage) | ||
.check(["aten_repeat_default", "aten_view_copy_default"]) | ||
) | ||
) | ||
|
||
def test_tosa_MI_dont_insert_view(self): | ||
""" | ||
When rank(input) == number of repeated dimensions (=4 in Repeat module), | ||
DON'T insert view. | ||
""" | ||
module = Repeat() | ||
inputs = (torch.rand((2, 3, 4, 1)),) | ||
test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass]) | ||
( | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=inputs, | ||
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), | ||
) | ||
.export() | ||
.to_edge() | ||
.check(["aten_repeat_default"]) | ||
.check_not(["aten_view_copy_default"]) | ||
.run_passes(test_pass_stage) | ||
.check(["aten_repeat_default"]) | ||
.check_not(["aten_view_copy_default"]) | ||
) | ||
) |