diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 7c376609d8..71dbdb5b85 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -36,9 +36,11 @@ def call_operator(self, op, args, kwargs, meta): ] # To convert expand arg to repeat arg, non-repeated dims should have - # multiples[dim] = 1. + # multiples[dim] = 1. Passing -1 to expand arg means + # not changing the size of that dimension. multiples = [ - multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank) + multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1 + for i in range(expanded_rank) ] return super().call_operator( op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index 05f72aa379..a8cdd48b40 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -36,6 +36,7 @@ class Expand(torch.nn.Module): (torch.ones(1, 1, 2, 2), (4, 3, -1, 2)), (torch.ones(1), (2, 2, 4)), (torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)), + (torch.ones(1, 1, 192), (1, -1, -1)), ] def forward(self, x: torch.Tensor, multiples: Sequence):