diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index 666ee23aa3..4f0ba3bd1a 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -25,7 +25,7 @@ import torchvision.models as models from torch.export import export, ExportedProgram from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner -from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge +from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge_transform_and_lower from executorch.exir.backend.backend_api import to_backend @@ -33,9 +33,10 @@ mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFA sample_inputs = (torch.randn(1, 3, 224, 224), ) exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs) -edge: EdgeProgramManager = to_edge(exported_program) - -edge = edge.to_backend(XnnpackPartitioner()) +edge: EdgeProgramManager = to_edge_transform_and_lower( + exported_program, + partitioner=[XnnpackPartitioner()], +) ``` We will go through this example with the [MobileNetV2](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/) pretrained model downloaded from the TorchVision library. The flow of lowering a model starts after exporting the model `to_edge`. We call the `to_backend` api with the `XnnpackPartitioner`. The partitioner identifies the subgraphs suitable for XNNPACK backend delegate to consume. Afterwards, the identified subgraphs will be serialized with the XNNPACK Delegate flatbuffer schema and each subgraph will be replaced with a call to the XNNPACK Delegate. @@ -47,16 +48,18 @@ GraphModule( (lowered_module_1): LoweredBackendModule() ) -def forward(self, arg314_1): + + +def forward(self, b_features_0_1_num_batches_tracked, ..., x): lowered_module_0 = self.lowered_module_0 - executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, arg314_1); lowered_module_0 = arg314_1 = None - getitem = executorch_call_delegate[0]; executorch_call_delegate = None - aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem, [1, 1280]); getitem = None - aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None lowered_module_1 = self.lowered_module_1 - executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, aten_clone_default); lowered_module_1 = aten_clone_default = None - getitem_1 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None - return (getitem_1,) + executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, x); lowered_module_1 = x = None + getitem_53 = executorch_call_delegate_1[0]; executorch_call_delegate_1 = None + aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(getitem_53, [1, 1280]); getitem_53 = None + aten_clone_default = executorch_exir_dialects_edge__ops_aten_clone_default(aten_view_copy_default); aten_view_copy_default = None + executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, aten_clone_default); lowered_module_0 = aten_clone_default = None + getitem_52 = executorch_call_delegate[0]; executorch_call_delegate = None + return (getitem_52,) ``` We print the graph after lowering above to show the new nodes that were inserted to call the XNNPACK Delegate. The subgraphs which are being delegated to XNNPACK are the first argument at each call site. It can be observed that the majority of `convolution-relu-add` blocks and `linear` blocks were able to be delegated to XNNPACK. We can also see the operators which were not able to be lowered to the XNNPACK delegate, such as `clone` and `view_copy`. @@ -75,7 +78,7 @@ The XNNPACK delegate can also execute symmetrically quantized models. To underst ```python from torch.export import export_for_training -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() sample_inputs = (torch.randn(1, 3, 224, 224), ) @@ -111,9 +114,11 @@ Quantization requires a two stage export. First we use the `export_for_training` ```python # Continued from earlier... -edge = to_edge(export(quantized_mobilenetv2, sample_inputs), compile_config=EdgeCompileConfig(_check_ir_validity=False)) - -edge = edge.to_backend(XnnpackPartitioner()) +edge = to_edge_transform_and_lower( + export(quantized_mobilenetv2, sample_inputs), + compile_config=EdgeCompileConfig(_check_ir_validity=False), + partitioner=[XnnpackPartitioner()] +) exec_prog = edge.to_executorch()