Skip to content

Commit

Permalink
Arm backend: Fix bug in ConvertExpandCopyToRepeatPass
Browse files Browse the repository at this point in the history
In the ConvertExpandCopyToRepeatPass the arguments for the repeat operation are formed incorrectly.
For the torch.Tensor.expand operation passing -1 as the size for a dimension means that the size of that dimension does not change.
For the DeiT-tiny case, torch.ones(1, 1, 192).expand(1, -1, -1) the pass will prepare arguments to the repeat operation as [1, -1, 1] which will cause an error, in this case the arguments should be [1, 1, 1].
  • Loading branch information
Aleksei-grovety committed Dec 5, 2024
1 parent ac8bf78 commit 62a2493
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
6 changes: 4 additions & 2 deletions backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def call_operator(self, op, args, kwargs, meta):
]

# To convert expand arg to repeat arg, non-repeated dims should have
# multiples[dim] = 1.
# multiples[dim] = 1. Passing -1 to expand arg means
# not changing the size of that dimension.
multiples = [
multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank)
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
for i in range(expanded_rank)
]
return super().call_operator(
op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
Expand Down
1 change: 1 addition & 0 deletions backends/arm/test/ops/test_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Expand(torch.nn.Module):
(torch.ones(1, 1, 2, 2), (4, 3, -1, 2)),
(torch.ones(1), (2, 2, 4)),
(torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)),
(torch.ones(1, 1, 192), (1, -1, -1)),
]

def forward(self, x: torch.Tensor, multiples: Sequence):
Expand Down

0 comments on commit 62a2493

Please sign in to comment.