From 4d8a7e491e430dca93bf103129af2cd66edc7907 Mon Sep 17 00:00:00 2001 From: Aleksei-grovety <113356454+Aleksei-grovety@users.noreply.github.com> Date: Thu, 5 Dec 2024 16:06:09 +0400 Subject: [PATCH] Arm backend: Fix bug in ConvertExpandCopyToRepeatPass In the ConvertExpandCopyToRepeatPass, the arguments for the repeat operation are formed incorrectly. For the torch.Tensor.expand operation, passing -1 as the size for a dimension means that the size of that dimension does not change. For the DeiT-tiny case, torch.ones(1, 1, 192).expand(1, -1, -1), the pass prepares the arguments to the repeat operation as [1, -1, 1], which causes an error; in this case the arguments should be [1, 1, 1]. --- backends/arm/_passes/convert_expand_copy_to_repeat.py | 6 ++++-- backends/arm/test/ops/test_expand.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 7c376609d8..71dbdb5b85 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -36,9 +36,11 @@ def call_operator(self, op, args, kwargs, meta): ] # To convert expand arg to repeat arg, non-repeated dims should have - # multiples[dim] = 1. + # multiples[dim] = 1. Passing -1 to expand arg means + # not changing the size of that dimension. 
multiples = [ - multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank) + multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1 + for i in range(expanded_rank) ] return super().call_operator( op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index 05f72aa379..a8cdd48b40 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -36,6 +36,7 @@ class Expand(torch.nn.Module): (torch.ones(1, 1, 2, 2), (4, 3, -1, 2)), (torch.ones(1), (2, 2, 4)), (torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)), + (torch.ones(1, 1, 192), (1, -1, -1)), ] def forward(self, x: torch.Tensor, multiples: Sequence):