Skip to content

Commit

Permalink
Dynamic ViT
Browse files Browse the repository at this point in the history
Differential Revision: D54972681
  • Loading branch information
mcr229 authored and facebook-github-bot committed Apr 11, 2024
1 parent 8d210d0 commit 8f391a4
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions backends/xnnpack/test/models/torchvision_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,29 @@ class TestViT(unittest.TestCase):
vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
vit = vit.eval()
model_inputs = (torch.ones(1, 3, 224, 224),)
dynamic_shapes = (
{
2: torch.export.Dim("height", min=224, max=455),
3: torch.export.Dim("width", min=224, max=455),
},
)

class DynamicViT(torch.nn.Module):
def __init__(self):
super().__init__()
self.vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
self.vit = self.vit.eval()

def forward(self, x):
x = torch.nn.functional.interpolate(
x,
size=(224, 224),
mode="bilinear",
align_corners=True,
antialias=False,
)
return self.vit(x)

all_operators = {
"executorch_exir_dialects_edge__ops_aten_expand_copy_default",
"executorch_exir_dialects_edge__ops_aten_cat_default",
Expand All @@ -34,7 +57,8 @@ class TestViT(unittest.TestCase):
"executorch_exir_dialects_edge__ops_aten_bmm_default",
}

def test_fp32_vit(self):
def _test_exported_vit(self, tester, check_nots=None):
check_nots = check_nots or []
lowerable_xnn_operators = self.all_operators - {
"executorch_exir_dialects_edge__ops_aten_expand_copy_default",
"executorch_exir_dialects_edge__ops_aten_gelu_default",
Expand All @@ -48,14 +72,33 @@ def test_fp32_vit(self):
"executorch_exir_dialects_edge__ops_aten_bmm_default",
}
(
Tester(self.vit, self.model_inputs)
.export()
tester.export()
.to_edge()
.check(list(self.all_operators))
.partition()
.check(["torch.ops.higher_order.executorch_call_delegate"])
.check_not(list(lowerable_xnn_operators))
.check_not(check_nots)
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_fp32_vit(self):
self._test_exported_vit(Tester(self.vit, self.model_inputs))

def test_dynamic_vit(self):
bilinear_ops = {
"executorch_exir_dialects_edge__ops_aten_sub_Tensor",
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
"executorch_exir_dialects_edge__ops_aten_index_Tensor",
"executorch_exir_dialects_edge__ops_aten_arange_start_step",
"executorch_exir_dialects_edge__ops_aten__to_copy_default",
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
"executorch_exir_dialects_edge__ops_aten_clamp_default",
}

self._test_exported_vit(
Tester(self.DynamicViT(), self.model_inputs, self.dynamic_shapes),
bilinear_ops,
)

0 comments on commit 8f391a4

Please sign in to comment.