Dynamic Conv1d + W2L (#2976)
Summary:
Pull Request resolved: #2976

Conv1d uses the static reshape operator to convert its 3D tensor to a 4D tensor so that XNNPACK can run it as conv2d.
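
As a sketch of why that reshape is sufficient (plain torch modules for illustration, not the delegate's code): conv1d on an (N, C, L) tensor matches conv2d on the same tensor with a trailing unit dimension appended and the kernel reshaped accordingly.

    import torch

    # conv1d on (N, C, L) == conv2d on (N, C, L, 1) with a (k, 1) kernel.
    conv1d = torch.nn.Conv1d(2, 3, kernel_size=3)
    conv2d = torch.nn.Conv2d(2, 3, kernel_size=(3, 1))
    with torch.no_grad():
        conv2d.weight.copy_(conv1d.weight.unsqueeze(-1))  # (3, 2, 3) -> (3, 2, 3, 1)
        conv2d.bias.copy_(conv1d.bias)

    x = torch.randn(2, 2, 4)
    y_1d = conv1d(x)
    y_2d = conv2d(x.unsqueeze(-1)).squeeze(-1)  # 3D -> 4D, conv2d, back to 3D
    assert torch.allclose(y_1d, y_2d, atol=1e-6)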

For dynamic shapes, the reshape accepts only a single dynamic dimension, which is denoted as dynamic by setting its dim value to 0 in the new shape.
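
As a standalone illustration of where those dynamic dimensions come from (the Unsqueeze module and variable names below are for demonstration only, not part of this diff): exporting with a torch.export.Dim leaves a torch.SymInt in the placeholder's shape, and mapping every SymInt to 0 produces the new shape that the serializer passes to XNNStaticReshape.

    import torch

    class Unsqueeze(torch.nn.Module):
        def forward(self, x):
            # Same trailing unit dim the delegate inserts before running conv2d.
            return x.unsqueeze(-1)

    # Export with a dynamic batch dimension, mirroring the tests in this diff.
    batch = torch.export.Dim("batch", min=2, max=10)
    ep = torch.export.export(
        Unsqueeze(), (torch.randn(4, 1, 700),), dynamic_shapes=({0: batch},)
    )

    # The placeholder's fake tensor carries a torch.SymInt for the batch dim;
    # mapping every SymInt to 0 yields the reshape's new_shape.
    val = next(n for n in ep.graph.nodes if n.op == "placeholder").meta["val"]
    new_shape = [0 if isinstance(d, torch.SymInt) else int(d) for d in val.shape]
    print(new_shape)  # [0, 1, 700]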

Reviewed By: digantdesai, kirklandsign

Differential Revision: D55815092
mcr229 authored and facebook-github-bot committed Apr 11, 2024
1 parent 8f391a4 commit add2311
Showing 3 changed files with 75 additions and 19 deletions.
32 changes: 30 additions & 2 deletions backends/xnnpack/operators/op_squeeze.py
@@ -53,7 +53,21 @@ def define_node(
             "val" in input_node.meta,
             "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
         )
-        new_shape = get_shape(input_node)[:-1]
+        dynamic_shape = node.meta["val"].shape
+        new_shape = []
+
+        num_dynamic_dims = 0
+        for dim in dynamic_shape:
+            if isinstance(dim, torch.SymInt):
+                num_dynamic_dims += 1
+                new_shape.append(0)
+            else:
+                new_shape.append(dim)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
+        )
+
         ser_node = XNode(
             xnode_union=XNNStaticReshape(
@@ -101,7 +115,21 @@ def define_node(
             "val" in input_node.meta,
             "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
         )
-        new_shape = get_shape(input_node) + [1]
+        dynamic_shape = node.meta["val"].shape
+        new_shape = []
+
+        num_dynamic_dims = 0
+        for dim in dynamic_shape:
+            if isinstance(dim, torch.SymInt):
+                num_dynamic_dims += 1
+                new_shape.append(0)
+            else:
+                new_shape.append(dim)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
+        )
+
         ser_node = XNode(
             xnode_union=XNNStaticReshape(
12 changes: 7 additions & 5 deletions backends/xnnpack/test/models/w2l.py
@@ -15,13 +15,15 @@ class TestW2L(unittest.TestCase):
     batch_size = 10
     input_frames = 700
     vocab_size = 4096
+    num_features = 1
     wav2letter = models.Wav2Letter(num_classes=vocab_size).eval()
 
-    model_inputs = (torch.randn(batch_size, 1, input_frames),)
+    model_inputs = (torch.randn(batch_size, num_features, input_frames),)
+    dynamic_shape = ({0: torch.export.Dim("batch", min=2, max=10)},)
 
     def test_fp32_w2l(self):
         (
-            Tester(self.wav2letter, self.model_inputs)
+            Tester(self.wav2letter, self.model_inputs, self.dynamic_shape)
             .export()
             .to_edge()
             .partition()
@@ -34,12 +36,12 @@ def test_fp32_w2l(self):
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(num_runs=10)
         )
 
     def test_qs8_w2l(self):
         (
-            Tester(self.wav2letter.eval(), self.model_inputs)
+            Tester(self.wav2letter.eval(), self.model_inputs, self.dynamic_shape)
             .quantize()
             .export()
             .to_edge()
@@ -53,5 +55,5 @@ def test_qs8_w2l(self):
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(num_runs=10)
         )
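
For reference, a sketch of the export the updated test drives through the Tester pipeline, assuming torchaudio's Wav2Letter (this calls torch.export directly; the Tester API itself is used exactly as shown in the diff above):

    import torch
    from torchaudio import models

    # Batch is the single dynamic dimension, bounded like the test's dynamic_shape.
    batch = torch.export.Dim("batch", min=2, max=10)
    wav2letter = models.Wav2Letter(num_classes=4096).eval()
    example_inputs = (torch.randn(10, 1, 700),)
    exported = torch.export.export(
        wav2letter, example_inputs, dynamic_shapes=({0: batch},)
    )
    print(exported.graph_module.graph)
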
50 changes: 38 additions & 12 deletions backends/xnnpack/test/ops/conv1d.py
@@ -81,9 +81,15 @@ def forward(self, x):
             z = torch.add(y, z)
             return z
 
-    def _test_conv1d(self, module, inputs, conv_count, quantized=False):
+    def _test_conv1d(
+        self, module, inputs, conv_count, quantized=False, dynamic_shape=None
+    ):
         (
-            (Tester(module, inputs).quantize() if quantized else Tester(module, inputs))
+            (
+                Tester(module, inputs, dynamic_shape).quantize()
+                if quantized
+                else Tester(module, inputs)
+            )
             .export()
             .check_count({"torch.ops.aten.convolution.default": conv_count})
             .to_edge()
@@ -101,21 +107,41 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False):
         )
 
     def test_fp16_conv1d(self):
-        inputs = (torch.randn(1, 2, 4).to(torch.float16),)
-        self._test_conv1d(self.Conv1d(dtype=torch.float16), inputs, conv_count=1)
+        inputs = (torch.randn(2, 2, 4).to(torch.float16),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1d(dtype=torch.float16),
+            inputs,
+            conv_count=1,
+            dynamic_shape=dynamic_shapes,
+        )
 
     def test_fp32_conv1d(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1d(), inputs, 1)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(self.Conv1d(), inputs, 1, dynamic_shape=dynamic_shapes)
 
     def test_fp32_conv1d_batchnorm_seq(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1dBatchNormSequential(), inputs, 2, dynamic_shape=dynamic_shapes
+        )
 
     def test_qs8_conv1d(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1d(), inputs, 1, quantized=True)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1d(), inputs, 1, quantized=True, dynamic_shape=dynamic_shapes
+        )
 
     def test_qs8_conv1d_batchnorm_seq(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2, quantized=True)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1dBatchNormSequential(),
+            inputs,
+            2,
+            quantized=True,
+            dynamic_shape=dynamic_shapes,
+        )
