From 63ff1aefedaf4c2cef4cb48ec48f65201fff740d Mon Sep 17 00:00:00 2001 From: Zingo Andersen Date: Tue, 5 Nov 2024 18:02:05 +0100 Subject: [PATCH 01/59] Arm backend: Use better Ethos-U PMU counters for Ethos-U85 Differential Revision: D65147935 Pull Request resolved: https://github.com/pytorch/executorch/pull/6455 --- backends/arm/third-party/ethos-u-core-driver | 2 +- .../arm/executor_runner/arm_perf_monitor.cpp | 31 ++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/backends/arm/third-party/ethos-u-core-driver b/backends/arm/third-party/ethos-u-core-driver index 90f9df900a..78df0006c5 160000 --- a/backends/arm/third-party/ethos-u-core-driver +++ b/backends/arm/third-party/ethos-u-core-driver @@ -1 +1 @@ -Subproject commit 90f9df900acdc0718ecd2dfdc53780664758dec5 +Subproject commit 78df0006c5fa667150d3ee35db7bde1d3f6f58c7 diff --git a/examples/arm/executor_runner/arm_perf_monitor.cpp b/examples/arm/executor_runner/arm_perf_monitor.cpp index 323010bfd7..b75e510d9d 100644 --- a/examples/arm/executor_runner/arm_perf_monitor.cpp +++ b/examples/arm/executor_runner/arm_perf_monitor.cpp @@ -24,7 +24,14 @@ static std::vector ethosu_pmuEventCounts( ETHOSU_PMU_Get_NumEventCounters(), 0); +#if defined(ETHOSU55) || defined(ETHOSU65) static const uint32_t ethosu_pmuCountersUsed = 4; +#elif defined(ETHOSU85) +static const uint32_t ethosu_pmuCountersUsed = 5; +#else +#error No NPU target defined +#endif + // ethosu_pmuCountersUsed should match numbers of counters setup in // ethosu_inference_begin() and not be more then the HW supports static_assert(ETHOSU_PMU_NCOUNTERS >= ethosu_pmuCountersUsed); @@ -44,18 +51,26 @@ void ethosu_inference_begin(struct ethosu_driver* drv, void*) { ETHOSU_PMU_Set_EVTYPER(drv, 1, ETHOSU_PMU_AXI1_RD_DATA_BEAT_RECEIVED); ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_AXI0_WR_DATA_BEAT_WRITTEN); ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_NPU_IDLE); + // Enable the 4 counters + ETHOSU_PMU_CNTR_Enable( + drv, + ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CNT2_Msk | ETHOSU_PMU_CNT3_Msk | + ETHOSU_PMU_CNT4_Msk); #elif defined(ETHOSU85) - ETHOSU_PMU_Set_EVTYPER(drv, 0, ETHOSU_PMU_EXT0_RD_DATA_BEAT_RECEIVED); - ETHOSU_PMU_Set_EVTYPER(drv, 1, ETHOSU_PMU_EXT1_RD_DATA_BEAT_RECEIVED); - ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_EXT0_WR_DATA_BEAT_WRITTEN); - ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_NPU_IDLE); + ETHOSU_PMU_Set_EVTYPER(drv, 0, ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED); + ETHOSU_PMU_Set_EVTYPER(drv, 1, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN); + ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED); + ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN); + ETHOSU_PMU_Set_EVTYPER(drv, 4, ETHOSU_PMU_NPU_IDLE); + // Enable the 5 counters + ETHOSU_PMU_CNTR_Enable( + drv, + ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CNT2_Msk | ETHOSU_PMU_CNT3_Msk | + ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk); #else #error No NPU target defined #endif - // Enable 4 counters - ETHOSU_PMU_CNTR_Enable(drv, 0xf); - ETHOSU_PMU_CNTR_Enable(drv, ETHOSU_PMU_CCNT_Msk); ETHOSU_PMU_CYCCNT_Reset(drv); @@ -177,7 +192,7 @@ void StopMeasurements() { #elif defined(ETHOSU85) ET_LOG( Info, - "Ethos-U PMU Events:[ETHOSU_PMU_EXT0_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT1_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT0_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE]"); + "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE]"); #else #error No NPU target defined #endif From 363505f968d80eac0226330df680bdef43ee1ff6 Mon Sep 17 00:00:00 2001 From: azad-meta <148276886+azad-meta@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:43:25 -0500 Subject: [PATCH 02/59] adding suppression tags to improve autodeps noise Differential Revision: D65401749 Pull Request resolved: https://github.com/pytorch/executorch/pull/6630 --- backends/arm/TARGETS | 1 + backends/arm/operators/TARGETS | 1 + 2 files changed, 2 insertions(+) diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index 39910f0150..0dc8797be5 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -1,3 +1,4 @@ +# @noautodeps load("@fbcode_macros//build_defs:python_library.bzl", "python_library") python_library( diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS index fd04d5fb84..c2aa8d2dfb 100644 --- a/backends/arm/operators/TARGETS +++ b/backends/arm/operators/TARGETS @@ -1,3 +1,4 @@ +# @noautodeps load("@fbcode_macros//build_defs:python_library.bzl", "python_library") python_library( From d99d26e7e1052ad01a634b7f4337fdbc40f013e4 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 5 Nov 2024 12:52:40 -0500 Subject: [PATCH 03/59] c10::optional -> std::optional Differential Revision: D65439045 Pull Request resolved: https://github.com/pytorch/executorch/pull/6642 --- kernels/quantized/cpu/embeddingxb.cpp | 12 ++++++------ kernels/quantized/cpu/embeddingxb.h | 8 ++++---- kernels/quantized/cpu/op_dequantize.cpp | 6 +++--- kernels/quantized/cpu/op_embedding.cpp | 12 ++++++------ kernels/quantized/cpu/op_embedding2b.cpp | 8 ++++---- kernels/quantized/cpu/op_embedding4b.cpp | 8 ++++---- kernels/quantized/cpu/op_mixed_linear.cpp | 12 ++++++------ kernels/quantized/cpu/op_mixed_mm.cpp | 6 +++--- 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/kernels/quantized/cpu/embeddingxb.cpp b/kernels/quantized/cpu/embeddingxb.cpp index f8fdfe078c..5275f842df 100644 --- a/kernels/quantized/cpu/embeddingxb.cpp +++ b/kernels/quantized/cpu/embeddingxb.cpp @@ -65,7 +65,7 @@ static inline int32_t get_embedding_dim( void check_embedding_xbit_args( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -170,7 +170,7 @@ template void embedding_xbit_per_channel( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const Tensor& indices, Tensor& out, int weight_nbit) { @@ -260,7 +260,7 @@ Tensor& quantized_embedding_xbit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -299,7 +299,7 @@ Tensor& quantized_embedding_xbit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -325,7 +325,7 @@ Tensor& quantized_embedding_xbit_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -368,7 +368,7 @@ Tensor& quantized_embedding_xbit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/embeddingxb.h b/kernels/quantized/cpu/embeddingxb.h index ae1fccc6c2..d08c8ae745 100644 --- a/kernels/quantized/cpu/embeddingxb.h +++ b/kernels/quantized/cpu/embeddingxb.h @@ -24,7 +24,7 @@ Tensor& quantized_embedding_xbit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -35,7 +35,7 @@ Tensor& quantized_embedding_xbit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -47,7 +47,7 @@ Tensor& quantized_embedding_xbit_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -59,7 +59,7 @@ Tensor& quantized_embedding_xbit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 9f8a365b9c..8d73d06694 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -186,7 +186,7 @@ float get_scale(const Tensor& scale, size_t channel_ix) { Tensor& dequantize_per_channel_out( const Tensor& input, const Tensor& scale, - const optional& opt_zero_points, + const exec_aten::optional& opt_zero_points, int64_t axis, int64_t quant_min, int64_t quant_max, @@ -261,7 +261,7 @@ Tensor& dequantize_per_channel_out( const auto* input_data_ptr = input.const_data_ptr(); \ ET_CHECK_MSG( \ axis == 0, "Axis must be 0 for a single dimensional tensors"); \ - const optional dim; \ + const exec_aten::optional dim; \ apply_over_dim( \ [input_data_ptr, out_data_ptr, zero_point_data, &scale]( \ size_t numel, size_t stride, size_t base_ix) { \ @@ -331,7 +331,7 @@ Tensor& dequantize_per_channel_out( KernelRuntimeContext& context, const Tensor& input, const Tensor& scale, - const optional& opt_zero_points, + const exec_aten::optional& opt_zero_points, int64_t axis, int64_t quant_min, int64_t quant_max, diff --git a/kernels/quantized/cpu/op_embedding.cpp b/kernels/quantized/cpu/op_embedding.cpp index e48e9a7eea..0ffe363f2a 100644 --- a/kernels/quantized/cpu/op_embedding.cpp +++ b/kernels/quantized/cpu/op_embedding.cpp @@ -27,7 +27,7 @@ namespace { void check_embedding_byte_args( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -129,7 +129,7 @@ template void embedding_byte_per_channel( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const Tensor& indices, Tensor& out) { // An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a @@ -218,7 +218,7 @@ Tensor& quantized_embedding_byte_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -253,7 +253,7 @@ Tensor& quantized_embedding_byte_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -277,7 +277,7 @@ Tensor& quantized_embedding_byte_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -316,7 +316,7 @@ Tensor& quantized_embedding_byte_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_embedding2b.cpp b/kernels/quantized/cpu/op_embedding2b.cpp index 0fdd7b731f..a2d2f8eb39 100644 --- a/kernels/quantized/cpu/op_embedding2b.cpp +++ b/kernels/quantized/cpu/op_embedding2b.cpp @@ -37,7 +37,7 @@ Tensor& quantized_embedding_2bit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -57,7 +57,7 @@ Tensor& quantized_embedding_2bit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -77,7 +77,7 @@ Tensor& quantized_embedding_2bit_out( Tensor& quantized_embedding_2bit_dtype_out( const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -99,7 +99,7 @@ Tensor& quantized_embedding_2bit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_embedding4b.cpp b/kernels/quantized/cpu/op_embedding4b.cpp index 8a99073cd0..d123b40b35 100644 --- a/kernels/quantized/cpu/op_embedding4b.cpp +++ b/kernels/quantized/cpu/op_embedding4b.cpp @@ -37,7 +37,7 @@ Tensor& quantized_embedding_4bit_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -57,7 +57,7 @@ Tensor& quantized_embedding_4bit_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, @@ -79,7 +79,7 @@ Tensor& quantized_embedding_4bit_dtype_out( // non quant input and returns fp output const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, const int64_t weight_quant_min, const int64_t weight_quant_max, const Tensor& indices, @@ -101,7 +101,7 @@ Tensor& quantized_embedding_4bit_dtype_out( KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, int64_t weight_quant_min, int64_t weight_quant_max, const Tensor& indices, diff --git a/kernels/quantized/cpu/op_mixed_linear.cpp b/kernels/quantized/cpu/op_mixed_linear.cpp index d3552e1ca6..af3d10cedb 100644 --- a/kernels/quantized/cpu/op_mixed_linear.cpp +++ b/kernels/quantized/cpu/op_mixed_linear.cpp @@ -19,8 +19,8 @@ bool check_quantized_mixed_linear_args( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const optional dtype, + const exec_aten::optional& opt_weight_zero_points, + const exec_aten::optional dtype, Tensor& out) { ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2)); ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2)); @@ -64,8 +64,8 @@ Tensor& quantized_mixed_linear_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const optional dtype, + const exec_aten::optional& opt_weight_zero_points, + const exec_aten::optional dtype, Tensor& out) { // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available. ET_CHECK(check_quantized_mixed_linear_args( @@ -117,8 +117,8 @@ Tensor& quantized_mixed_linear_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, - const optional dtype, + const exec_aten::optional& opt_weight_zero_points, + const exec_aten::optional dtype, Tensor& out) { // TODO(mcandales): Remove the need for this wrapper // TODO(mkg): add support for dtype diff --git a/kernels/quantized/cpu/op_mixed_mm.cpp b/kernels/quantized/cpu/op_mixed_mm.cpp index 895c7e0af3..18d8f1e70d 100644 --- a/kernels/quantized/cpu/op_mixed_mm.cpp +++ b/kernels/quantized/cpu/op_mixed_mm.cpp @@ -19,7 +19,7 @@ bool check_quantized_mixed_mm_args( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, Tensor& out) { ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2)); ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2)); @@ -55,7 +55,7 @@ Tensor& quantized_mixed_mm_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, Tensor& out) { ET_CHECK(check_quantized_mixed_mm_args( in, weight, weight_scales, opt_weight_zero_points, out)); @@ -92,7 +92,7 @@ Tensor& quantized_mixed_mm_out( const Tensor& in, const Tensor& weight, const Tensor& weight_scales, - const optional& opt_weight_zero_points, + const exec_aten::optional& opt_weight_zero_points, Tensor& out) { // TODO(mcandales): Remove the need for this wrapper (void)ctx; From cefe51594f717e8c9d089d4b0a55de7dc9479c56 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 5 Nov 2024 12:18:45 -0800 Subject: [PATCH 04/59] [ET-VK] Refine paritioner to account for storage type and memory layout (#6668) Pull Request resolved: https://github.com/pytorch/executorch/pull/6635 ## Context There are a variety of ways that tensors can be represented in Vulkan. The two main descriptors for how a tensor is laid out in memory is: 1. Storage Type (buffer or texture) 2. Memory Layout (which dim is packed along a texel, which dim has a stride of 1, etc.) Due to the differences between buffers and textures, and the differences between different memory layouts, an implementation for an operator may only support a specific set of (storage type, memory layout) combinations. Furthermore, if an operator implementation supports multiple (storage type, memory layout) combinations, there may be a "preferred" setting which results in optimal performance. These changes lay the foundation for the implementation of a memory metadata tagging graph transform, which will make sure that all tensors participating in an operator call is has a valid/optimal (storage type, memory layout) setting, and insert transition operators to transfer input tensors to the correct memory settings when necessary. An additional change that is required arises from the fact that in Vulkan, there is a limit on texture and buffer sizes. Therefore, the partitioner needs to account for the storage types and memory layouts supported by the operator implementation, and check if all tensors participating in a computation can be represented with some storage type, memory layout combination supported by the implementation. ## Changes Improvements to the operator registry: * Introduce utility functions to check the optimal and enabled storage types and memory layouts for an operator Improvements to the Partitioner: * Account for the storage types and memory layouts supported by an operator when deciding if a node should be partitioned * Improved logic for fusable ops (i.e. the permute/transpose before a mm which can be fused into linear) to check if the final target op is supported in Vulkan, and only partition those nodes if so. Otherwise, don't partition it so that it can be fused by another backend. ghstack-source-id: 251883705 @exported-using-ghexport Differential Revision: [D65428843](https://our.internmc.facebook.com/intern/diff/D65428843/) Co-authored-by: Stephen Jia --- backends/vulkan/op_registry.py | 193 ++++++++++++--- backends/vulkan/partitioner/TARGETS | 1 + .../vulkan/partitioner/vulkan_partitioner.py | 229 ++++++++++++------ backends/vulkan/targets.bzl | 1 + backends/vulkan/utils.py | 142 +++++++++++ 5 files changed, 455 insertions(+), 111 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index fe67fdb30c..3a6191bccb 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -8,18 +8,31 @@ import operator -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, Optional, Set, Union import executorch.backends.vulkan.custom_ops_lib # noqa import torch -from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) + +from executorch.backends.vulkan.utils import ( + all_memory_layouts, + all_packed_dims, + PackedDim, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._subclasses.fake_tensor import FakeTensor +###################### +## OpFeatures class ## +###################### + def allow_node(node: torch.fx.Node) -> bool: return True @@ -27,25 +40,37 @@ def allow_node(node: torch.fx.Node) -> bool: class TextureImplFeatures: __slots__ = [ - # Indicates if the compute shader is agnostic to the packed dimension - "uses_packed_dim", - # Indicates if the compute shader is agnostic to the texture axis mapping + "valid_packed_dims", "uses_axis_map", - # Specifies a specific set of memory layouts that the shader supports. If it is - # and empty list, then the supported memory layouts can be inferred from the - # `uses_packed_dim` and `uses_axis_map` flags. - "supported_layouts", ] def __init__( self, - uses_packed_dim: bool = False, uses_axis_map: bool = False, - supported_layouts: Optional[List[VkMemoryLayout]] = None, + valid_packed_dims: Optional[Set[PackedDim]] = None, ): - self.uses_packed_dim: bool = uses_packed_dim self.uses_axis_map: bool = uses_axis_map - self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts + self.valid_packed_dims = set() + if valid_packed_dims is not None: + self.valid_packed_dims = valid_packed_dims + + def valid_memory_layouts(self) -> Set[VkMemoryLayout]: + """ + Derive the set of memory layouts supported by the texture implementation based + on the valid packed dimensions. + """ + layouts = set() + + if PackedDim.WIDTH in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED) + + if PackedDim.HEIGHT in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED) + + if PackedDim.CHANNELS in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED) + + return layouts class OpFeatures: @@ -58,6 +83,9 @@ class OpFeatures: # bool indicating if the operator has a resize function, which allows it to # support dynamic shape tensors. "resize_fn", + # Optimal + "optimal_storage", + "optimal_layout", # bool indicating if the operator handles its own prepacking. If this is True, # then the insert_prepack_nodes pass will not insert prepack nodes for the args # of the op. @@ -72,17 +100,90 @@ def __init__( texture_impl: Optional[TextureImplFeatures] = None, buffer_impl: bool = False, resize_fn: bool = False, + optimal_storage: Optional[VkStorageType] = None, + optimal_layout: Optional[VkMemoryLayout] = None, handles_own_prepacking: bool = False, check_node_fn: Optional[Callable] = None, ): self.texture_impl: Optional[TextureImplFeatures] = texture_impl self.buffer_impl: bool = buffer_impl self.resize_fn: bool = resize_fn + self.optimal_storage: Optional[VkStorageType] = optimal_storage + self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout self.handles_own_prepacking: bool = handles_own_prepacking self.check_node_fn: Callable = allow_node if check_node_fn is not None: self.check_node_fn = check_node_fn + def propose_storage_type(self) -> Optional[VkStorageType]: + """ + Propose a storage type that should be used for this operator. A proposal can be + made if one of the following is true: + 1. The operator specifies an optimal storage type + 2. Only one storage type is supported. + + If both storage types are supported and no optimal storage type is specified, + then None is returned to indicate that there is no preference in storage type. + """ + if self.optimal_storage is not None: + return self.optimal_storage + + if self.texture_impl is not None and not self.buffer_impl: + return VkStorageType.TEXTURE_3D + elif self.buffer_impl and self.texture_impl is None: + return VkStorageType.BUFFER + + return None + + def supported_storage_types(self) -> Set[VkStorageType]: + """ + Return the set of storage types supported by this operator. + """ + storage_types = set() + if self.texture_impl is not None: + storage_types.add(VkStorageType.TEXTURE_3D) + if self.buffer_impl: + storage_types.add(VkStorageType.BUFFER) + + return storage_types + + def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]: + """ + Given a storage type as a precondition, propose a memory layout that should be + used for this operator. A proposal can be made if one of the following is true: + 1. The operator specifies an optimal memory layout + 2. Only one memory layout is supported. + + If multiple memory layouts are supported and no optimal memory layout is + specified then return None to indicate that the "best" memory layout for the + operator is ambiguous. + """ + if self.optimal_layout is not None: + return self.optimal_layout + + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + possible_layouts = self.texture_impl.valid_memory_layouts() + if len(possible_layouts) == 1: + return next(iter(possible_layouts)) + + return None + + def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]: + """ + Return the set of memory layouts supported by this operator for a given storage + type. + """ + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + return self.texture_impl.valid_memory_layouts() + else: + return all_memory_layouts + + +####################### +## Operator Registry ## +####################### OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload] @@ -122,8 +223,8 @@ def update_features_impl(op: OpKey): ) def register_ephemeral_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -143,8 +244,8 @@ def register_ephemeral_op(features: OpFeatures): ) def register_binary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -170,8 +271,8 @@ def register_binary_op(features: OpFeatures): ) def register_unary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -181,8 +282,8 @@ def register_unary_op(features: OpFeatures): @update_features(exir_ops.edge.aten._to_copy.default) def register_to_copy_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -220,15 +321,16 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: ) def register_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=True, - supported_layouts=[ - VkMemoryLayout.TENSOR_WIDTH_PACKED, - VkMemoryLayout.TENSOR_CHANNELS_PACKED, - ], + valid_packed_dims={ + PackedDim.WIDTH, + PackedDim.CHANNELS, + }, ) features.buffer_impl = True features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -236,12 +338,13 @@ def register_mm_op(features: OpFeatures): @update_features(exir_ops.edge.aten._weight_int8pack_mm.default) def register_int8_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.buffer_impl = True features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -249,11 +352,12 @@ def register_int8_mm_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.linear_weight_int4.default) def register_int4_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -266,7 +370,7 @@ def register_int4_mm_op(features: OpFeatures): ) def register_softmax_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -282,7 +386,7 @@ def register_softmax_op(features: OpFeatures): ) def register_reduce_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -309,7 +413,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool: ) def register_2d_pool_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True return features @@ -323,9 +427,11 @@ def register_2d_pool_op(features: OpFeatures): ) def register_convolution_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED features.handles_own_prepacking = True return features @@ -333,9 +439,11 @@ def register_convolution_op(features: OpFeatures): @update_features("llama::sdpa_with_kv_cache") def register_sdpa_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -343,7 +451,7 @@ def register_sdpa_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True return features @@ -352,7 +460,7 @@ def register_rotary_emb_op(features: OpFeatures): @update_features(exir_ops.edge.aten.view_copy.default) def register_view_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -393,7 +501,7 @@ def register_view_op(features: OpFeatures): ) def register_ported_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) return features @@ -408,15 +516,24 @@ def register_ported_op(features: OpFeatures): ) def register_ported_ops_with_prepacking(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.handles_own_prepacking = True return features -## -## Utility Functions -## +####################### +## Utility functions ## +####################### + + +def has_impl(target: OpKey) -> bool: + if not isinstance(target, str): + if target not in vulkan_supported_ops: + return target.name() in vulkan_supported_ops + return target in vulkan_supported_ops + else: + return target in vulkan_supported_ops def get_op_features(target: OpKey) -> OpFeatures: diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS index d68a82ade0..1d1d29f6fb 100644 --- a/backends/vulkan/partitioner/TARGETS +++ b/backends/vulkan/partitioner/TARGETS @@ -13,6 +13,7 @@ runtime.python_library( ], deps = [ "//executorch/backends/vulkan:op_registry", + "//executorch/backends/vulkan:utils_lib", "//executorch/backends/vulkan:vulkan_preprocess", "//executorch/exir:delegate", "//executorch/exir:lib", diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 2e916fd581..c851eeb4da 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -9,12 +9,23 @@ import logging from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple -import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema +import executorch.backends.vulkan.utils as utils import torch -from executorch.backends.vulkan.op_registry import vulkan_supported_ops +from executorch.backends.vulkan.op_registry import ( + get_op_features, + has_impl, + OpFeatures, + vulkan_supported_ops, +) + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend + from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, @@ -24,7 +35,6 @@ from executorch.exir.backend.utils import tag_constant_data from executorch.exir.dialects._ops import ops as exir_ops -from torch._subclasses.fake_tensor import FakeTensor from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner @@ -40,104 +50,140 @@ class VulkanSupportedOperators(OperatorSupportBase): - def __init__(self, require_dynamic_shape: bool = False) -> None: + def __init__( + self, texture_limits: utils.ImageExtents, require_dynamic_shape: bool = False + ) -> None: super().__init__() self.require_dynamic_shapes = require_dynamic_shape - # The tensor dim limit is to guard against tensors with one or more - # large dimensions, which cannot be represented by an image texture due - # to the texture axis limits. - self.tensor_dim_limit = 16384 - - # pyre-ignore - def node_val_is_compatible(self, node_val: Any) -> bool: - # Skip nodes that don't have a value - if node_val is None: - return True + self.texture_limits: utils.ImageExtents = texture_limits - # TODO(ssjia) support symbolic ints - if isinstance(node_val, torch.SymInt): - return False - - if isinstance(node_val, FakeTensor): - # Vulkan currently only supports tensors of up to 4D - if len(node_val.shape) > 4: - return False + def op_node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + """ + Check if a given node is compatible with the Vulkan delegate's implementation + of the operator called by the node. Each tensor argument participating in the + operator call must be able to be represented with a (storage type, memory layout) + combination that is supported by the operator implementation. + """ + target = node.target + # Account for custom operators + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + assert isinstance(first_arg, torch._ops.OpOverload) + target = first_arg.name() - # bool dtype not currently supported - if node_val.dtype == torch.bool: - return False + # Extract the features for the node's operator, if no override was provided + if features is None: + if not has_impl(target): + return False, "no operator implementation" + features = get_op_features(target) - for dim in node_val.shape: - if dim > self.tensor_dim_limit: - return False + valid_texture_layouts = utils.possible_node_memory_layouts( + node, self.texture_limits + ) + for arg in node.args: + if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): + arg_texture_layouts = utils.possible_node_memory_layouts( + arg, self.texture_limits + ) + valid_texture_layouts = valid_texture_layouts.intersection( + arg_texture_layouts + ) + + # If there are no valid texture memory layouts, then buffer storage must be + # supported by the operator implementation. + if len(valid_texture_layouts) == 0: + # TODO: once memory metadata tagging pass is implemented, check that the + # op impl supports buffers instead + return False, "requires buffer representation" + + op_available_layouts = features.supported_memory_layouts( + VkStorageType.TEXTURE_3D + ) - if isinstance(node_val, (list, tuple)): - for item in node_val: - if not self.node_val_is_compatible(item): - return False + is_compatible = any( + layout in op_available_layouts for layout in valid_texture_layouts + ) + if not is_compatible: + return False, "Required texutre memory layout not supported" - return True + return is_compatible, "Op is compatible" - def all_args_compatible(self, node: torch.fx.Node) -> bool: - node_val = node.meta.get("val", None) - if not self.node_val_is_compatible(node_val): - return False + def node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + # TODO(ssjia) support symbolic ints + if utils.is_symint_node(node): + return False, "symint node not supported yet" + elif utils.is_tensor_node(node): + return self.op_node_is_compatible(node, features=features) - for arg in node.args: - if not isinstance(arg, torch.fx.Node): - continue + return False, f"Unsupported node type: {node.format_node()}" - arg_val = arg.meta.get("val", None) - if not self.node_val_is_compatible(arg_val): - return False + def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: + """ + Detect if a node is a permute/transpose that precedes a call to a `mm` or + `addmm` operator. This node can be fused with the `mm` or `addmm` to produce a + `linear` operator. - return True + This function returns two bool values: + 1. The first indicates if this node can be fused into a linear node + 2. The second indicates if the overall linear op can be executed with Vulkan - def is_linear_permute(self, node: torch.fx.Node) -> bool: + The node will be partitioned only if both are true. + """ if node.target not in [ exir_ops.edge.aten.t_copy.default, exir_ops.edge.aten.permute_copy.default, ]: - return False + return False, False if len(node.users) != 1: - return False + return False, False first_user = list(node.users.keys())[0] if first_user.target in [ exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, ]: - # Only mark this node if the overall linear op is valid - if self.all_args_compatible(first_user): - return True + # Only mark this node if the target linear op is valid + if self.node_is_compatible(first_user)[0]: + return True, True + else: + return True, False - return False + return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool: + def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: """ Scalar tensors are usually converted to scalar values in the graph via` scalar_tensor[0].item()` in Python, which translates to a chain of `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. This function marks the entire chain as supported by the Vulkan delegate. - Later, within vulkan_preprocess there will be a graph transform which - replaces the chain with passing in the scalar tensor directly. + Later, within vulkan_preprocess there will be a graph transform which replaces + the chain with passing in the scalar tensor directly. + + Similar to the `is_linear_permute` function, this function has 2 return values. """ if node.target == exir_ops.edge.aten.select_copy.int: if len(node.users) != 1: - return False + return False, False # pyre-ignore if node.args[0].meta["val"].numel() != 1: - return False + return False, False + + local_scalar_dense = list(node.users.keys())[0] + if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: + return False, False - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default + return self.is_in_local_scalar_dense_chain(local_scalar_dense) if node.target == torch.ops.aten._local_scalar_dense.default: - return True + return True, all(self.node_is_compatible(user)[0] for user in node.users) - return False + return False, False def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": @@ -148,26 +194,35 @@ def log_skip(self, node: torch.fx.Node, reason: str) -> None: def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: - r = self._is_node_supported(submodules, node) + r = self._is_node_supported(node) return r - def _is_node_supported( - self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node - ) -> bool: + def _is_node_supported(self, node: torch.fx.Node) -> bool: target = node.target if node.target == torch.ops.higher_order.auto_functionalized: first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() - if self.is_linear_permute(node): + is_linear_permute, target_linear_is_compatible = self.is_linear_permute(node) + if is_linear_permute and target_linear_is_compatible: return True + elif is_linear_permute: + # Skip so that the permute can be fused into a linear by another backend + self.log_skip(node, "permute node of non compatible linear node") + return False - if self.is_in_local_scalar_dense_chain(node): + is_in_local_scalar_dense_chain, dst_node_is_compatible = ( + self.is_in_local_scalar_dense_chain(node) + ) + if is_in_local_scalar_dense_chain and dst_node_is_compatible: return True + elif is_in_local_scalar_dense_chain: + self.log_skip(node, "local scalar dense of incompatible op node") + return False if target not in vulkan_supported_ops: - self.log_skip(node, "not in vulkan_supported_ops") + self.log_skip(node, "no operator implementation") return False features = vulkan_supported_ops[target] @@ -180,19 +235,38 @@ def _is_node_supported( self.log_skip(node, "no dynamic shape support") return False - return self.all_args_compatible(node) + is_compatible, reason = self.node_is_compatible(node, features=features) + if not is_compatible: + self.log_skip(node, reason) + + return is_compatible def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: compile_specs = [] for key, value in compile_options.items(): - if isinstance( - value, (vk_graph_schema.VkStorageType, vk_graph_schema.VkMemoryLayout) - ): + if isinstance(value, (VkStorageType, VkMemoryLayout)): value_bytes = int(value).to_bytes(4, byteorder="little") compile_specs.append(CompileSpec(key, value_bytes)) + if key == "texture_limits": + compile_specs.append( + CompileSpec( + "texture_limits_x", int(value[0]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_y", int(value[1]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_z", int(value[2]).to_bytes(4, byteorder="little") + ) + ) + # Unhandled options are ignored return compile_specs @@ -200,7 +274,10 @@ def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: @final class VulkanPartitioner(Partitioner): - def __init__(self, compile_options: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + compile_options: Optional[Dict[str, Any]] = None, + ) -> None: self.options: Dict[str, Any] = {} if compile_options is not None: self.options = compile_options @@ -218,9 +295,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # subgraphs containing the nodes with the tags partition_tags = {} + texture_limits: utils.ImageExtents = self.options.get( + "texture_limits", utils.DEFAULT_TEXTURE_LIMITS + ) capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - VulkanSupportedOperators(self.options.get("require_dynamic_shapes", False)), + VulkanSupportedOperators( + texture_limits, + require_dynamic_shape=self.options.get("require_dynamic_shapes", False), + ), allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 0d3b17cccc..9785b34951 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -253,6 +253,7 @@ def define_common_targets(is_fbcode = False): ], deps = [ ":custom_ops_lib", + ":utils_lib", "//caffe2:torch", "//executorch/exir/dialects:lib", "//executorch/backends/vulkan/serialization:lib", diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index ae0b8c6940..4264e94271 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -4,11 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from enum import IntEnum +from typing import Set, Tuple + import torch + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from torch._export.utils import is_buffer, is_param +from torch._subclasses.fake_tensor import FakeTensor + from torch.export import ExportedProgram +## +## Node type determination +## + def is_get_attr_node(node: torch.fx.Node) -> bool: return isinstance(node, torch.fx.Node) and node.op == "get_attr" @@ -28,3 +42,131 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: or is_buffer(program, node) or is_constant(program, node) ) + + +def is_symint_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a SymInt value + """ + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], torch.SymInt): + return True + + return False + + +def is_tensor_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a tensor value, or a collection of tensor values + """ + # All nodes with tensor values are tagged by the SpecPropPass transform + if "spec" in node.meta: + return True + + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], FakeTensor): + return True + + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + return all(isinstance(x, FakeTensor) for x in node.meta["val"]) + + return False + + +## +## Memory Layout, Storage Type Determination +## + +ImageExtents = Tuple[int, int, int] + +DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048) + + +class PackedDim(IntEnum): + WIDTH = 0 + HEIGHT = 1 + CHANNELS = 2 + + +all_packed_dims: Set[PackedDim] = { + PackedDim.WIDTH, + PackedDim.HEIGHT, + PackedDim.CHANNELS, +} + +all_storage_types: Set[VkStorageType] = { + VkStorageType.BUFFER, + VkStorageType.TEXTURE_3D, +} + +all_memory_layouts: Set[VkMemoryLayout] = { + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_HEIGHT_PACKED, + VkMemoryLayout.TENSOR_CHANNELS_PACKED, +} + + +def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents: + """ + Calculate the image extents that will be used to represent a tensor with the given sizes + and memory layout in the Vulkan Delegate. + """ + width = sizes[-1] if len(sizes) >= 1 else 1 + height = sizes[-2] if len(sizes) >= 2 else 1 + channels = sizes[-3] if len(sizes) >= 3 else 1 + batch = sizes[0] if len(sizes) >= 4 else 1 + + if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED: + width = (width + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED: + height = (height + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED: + channels = (channels + 3) // 4 + else: + raise RuntimeError(f"Unsupported memory layout {layout}") + + return width, height, channels * batch + + +def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool: + return all(extents[i] <= limits[i] for i in range(len(extents))) + + +def valid_texture_memory_layouts( + tensor_sizes: torch.Size, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given tensor sizes, determine the set of memory layouts which will prodice a texture + that can fit within the specified device limits. + """ + valid_layouts = set() + for layout in list(all_memory_layouts): + extents = required_image_extents(tensor_sizes, layout) + if extents_are_valid(extents, texture_limits): + valid_layouts.add(layout) + + return valid_layouts + + +def possible_node_memory_layouts( + node: torch.fx.Node, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given a node, determine the set of memory layouts which can be used to represent all + tensors involved in the computation. + """ + assert is_tensor_node(node) + if isinstance(node.meta["val"], FakeTensor): + return valid_texture_memory_layouts(node.meta["val"].shape, texture_limits) + valid_layouts = set() + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + for fake_tensor in node.meta["val"]: + valid_layouts = valid_layouts.union( + valid_texture_memory_layouts(fake_tensor.shape, texture_limits) + ) + + return valid_layouts From 836d5561a61877507b6d5891485725996bb6b32c Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 5 Nov 2024 12:33:22 -0800 Subject: [PATCH 05/59] [ET-VK] Introduce memory metadata tagging pass (#6669) * [ET-VK] Refine paritioner to account for storage type and memory layout Pull Request resolved: https://github.com/pytorch/executorch/pull/6635 ## Context There are a variety of ways that tensors can be represented in Vulkan. The two main descriptors for how a tensor is laid out in memory is: 1. Storage Type (buffer or texture) 2. Memory Layout (which dim is packed along a texel, which dim has a stride of 1, etc.) Due to the differences between buffers and textures, and the differences between different memory layouts, an implementation for an operator may only support a specific set of (storage type, memory layout) combinations. Furthermore, if an operator implementation supports multiple (storage type, memory layout) combinations, there may be a "preferred" setting which results in optimal performance. These changes lay the foundation for the implementation of a memory metadata tagging graph transform, which will make sure that all tensors participating in an operator call is has a valid/optimal (storage type, memory layout) setting, and insert transition operators to transfer input tensors to the correct memory settings when necessary. An additional change that is required arises from the fact that in Vulkan, there is a limit on texture and buffer sizes. Therefore, the partitioner needs to account for the storage types and memory layouts supported by the operator implementation, and check if all tensors participating in a computation can be represented with some storage type, memory layout combination supported by the implementation. ## Changes Improvements to the operator registry: * Introduce utility functions to check the optimal and enabled storage types and memory layouts for an operator Improvements to the Partitioner: * Account for the storage types and memory layouts supported by an operator when deciding if a node should be partitioned * Improved logic for fusable ops (i.e. the permute/transpose before a mm which can be fused into linear) to check if the final target op is supported in Vulkan, and only partition those nodes if so. Otherwise, don't partition it so that it can be fused by another backend. ghstack-source-id: 251883705 @exported-using-ghexport Differential Revision: [D65428843](https://our.internmc.facebook.com/intern/diff/D65428843/) * [ET-VK] Introduce memory metadata tagging pass Pull Request resolved: https://github.com/pytorch/executorch/pull/6636 ## Context As title; implements the memory metadata tagging graph transform described in the dependent diff. See the comments for more details. ghstack-source-id: 251884020 @exported-using-ghexport Differential Revision: [D65428842](https://our.internmc.facebook.com/intern/diff/D65428842/) --------- Co-authored-by: Stephen Jia --- backends/vulkan/_passes/TARGETS | 30 ++- backends/vulkan/_passes/__init__.py | 2 + .../vulkan/_passes/tag_memory_meta_pass.py | 236 ++++++++++++++++++ .../vulkan/partitioner/vulkan_partitioner.py | 8 +- .../serialization/vulkan_graph_builder.py | 16 ++ .../serialization/vulkan_graph_schema.py | 6 + backends/vulkan/targets.bzl | 2 + backends/vulkan/utils.py | 45 +++- backends/vulkan/vulkan_preprocess.py | 69 ++++- examples/models/llama/export_llama_lib.py | 2 +- extension/llm/export/partitioner_lib.py | 4 +- 11 files changed, 404 insertions(+), 16 deletions(-) create mode 100644 backends/vulkan/_passes/tag_memory_meta_pass.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index cf50f170cf..ed3d847933 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -16,6 +16,20 @@ runtime.python_library( ], ) +runtime.python_library( + name = "int4_weight_only_quantizer", + srcs = [ + "int4_weight_only_quantizer.py", + ], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//executorch/backends/vulkan:custom_ops_lib", + "//pytorch/ao:torchao", + ] +) + runtime.python_library( name = "remove_local_scalar_dense", srcs = ["remove_local_scalar_dense_ops.py"], @@ -30,17 +44,18 @@ runtime.python_library( ) runtime.python_library( - name = "int4_weight_only_quantizer", - srcs = [ - "int4_weight_only_quantizer.py", - ], + name = "tag_memory_meta_pass", + srcs = ["tag_memory_meta_pass.py"], visibility = [ "//executorch/backends/...", ], deps = [ - "//executorch/backends/vulkan:custom_ops_lib", - "//pytorch/ao:torchao", - ] + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/backends/vulkan:utils_lib", + "//executorch/backends/vulkan/serialization:lib", + ], ) runtime.python_library( @@ -56,5 +71,6 @@ runtime.python_library( ":insert_prepack_nodes", ":int4_weight_only_quantizer", ":remove_local_scalar_dense", + ":tag_memory_meta_pass" ] ) diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index cfdb7c6eee..8823553ab1 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -5,9 +5,11 @@ from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import ( RemoveLocalScalarDenseOpsTransform, ) +from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", "RemoveLocalScalarDenseOpsTransform", + "TagMemoryMetaPass", ] diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py new file mode 100644 index 0000000000..fd0bd3648e --- /dev/null +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from copy import deepcopy +from typing import Set + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.op_registry import get_op_features, has_impl + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) + +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass, PassResult + +from torch._subclasses.fake_tensor import FakeTensor + +from torch.fx.passes.tools_common import NodeList +from torch.fx.passes.utils.fuser_utils import topo_sort + +logger: logging.Logger = logging.getLogger("") +logger.setLevel(logging.INFO) + + +def set_memory_metadata( + node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout +) -> None: + utils.set_node_spec_attr(node, "vk_storage_type", storage) + utils.set_node_spec_attr(node, "vk_memory_layout", layout) + + +class TagMemoryMetaPass(ExportPass): + """ + There are a variety of ways that tensors can be represented in Vulkan. The two main + descriptors for how a tensor is laid out in memory is: + + 1. Storage Type (buffer or texture) + 2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.) + + Due to the differences between buffers and textures, and the differences between + different memory layouts, an implementation for an operator may only support a + specific set of (storage type, memory layout) combinations. + + Furthermore, if an operator implementation supports multiple (storage type, memory + layout) combinations, there may be a "preferred" setting which results in optimal + performance. + + This pass is responsible for ensuring that all tensors participating in an operator + call have a valid/optimal (storage type, memory layout) setting, and insert + transition operators to transfer input tensors to the correct memory settings when + necessary. + """ + + def __init__( + self, + texture_limits: utils.ImageExtents, + default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D, + default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED, + ): + super().__init__() + self.default_storage: VkStorageType = default_storage_type + self.default_layout: VkMemoryLayout = default_memory_layout + self.texture_limits = texture_limits + + def propose_node_storage( + self, + node: torch.fx.Node, + ) -> VkStorageType: + """ + Uses the operator registry to determine the storage type that should be used for + a given node. The storage type is determined with the following priorities: + 1. In some cases, a tensor involved in the computation may be too large to be + represented as a texture. If this is the case, the node is "opinionated" and + buffer representation must be used. + 1. If the operator called by the node indicates an optimal storage type, or only + supports a single storage type, use that storage type. If either is true, + then the node is considered to be opinionated as well. If multiple storage + and no preferred storage type is indicated, then the node is not opinionated; + go to the next step. + 2. If the node's arguments already have memory metadata annotations, then + preserve the settings of the first argument. Otherwise, proceed to the next + step. + 3. Recursively search the node's uses to see if any subsequent uses are + opinionated; inherit the settings of the first opinionated node. If no + opinionated user can be found, then proceed to the last step. + 4. Use the default storage type setting. + """ + # The node may have an input/output tensor that is too big to be stored in a + # texture. In this case, buffer storage must be used. Note that the partitioner + # has already checked for the fact that buffer storage is supported by the + # operator. + if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0: + return VkStorageType.BUFFER + + valid_storage_types: Set[VkStorageType] = utils.all_storage_types + + # pyre-ignore + if has_impl(node.target): + # pyre-ignore + features = get_op_features(node.target) + valid_storage_types = features.supported_storage_types() + storage = features.propose_storage_type() + if storage is not None: + return storage + + for arg in node.args: + if isinstance(arg, torch.fx.Node) and isinstance( + arg.meta["val"], FakeTensor + ): + storage = utils.get_node_storage_type(arg) + if storage is not None and storage in valid_storage_types: + return storage + + # If no storage type has been resolved yet, assume the optimal storage type of + # the first opinionated user. This search is recursive. + for user in node.users: + optimal_storage = self.propose_node_storage(user) + if optimal_storage is not None: + return optimal_storage + + if self.default_storage in valid_storage_types: + return self.default_storage + else: + return next(iter(valid_storage_types)) + + def propose_node_layout( + self, + node: torch.fx.Node, + storage: VkStorageType, + ) -> VkMemoryLayout: + """ + Performs the same steps as propose_node_storage, but detects the memory layout + that should be used for the specific storage type. The same prioritization logic + is applied. + """ + valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts + # pyre-ignore + if has_impl(node.target): + # pyre-ignore + features = get_op_features(node.target) + valid_layouts = features.supported_memory_layouts(storage) + layout = features.propose_memory_layout(storage) + if layout is not None: + return layout + + for arg in node.args: + if isinstance(arg, torch.fx.Node) and isinstance( + arg.meta["val"], FakeTensor + ): + layout = utils.get_node_memory_layout(arg) + if layout is not None and layout in valid_layouts: + return layout + + # If no storage type has been resolved yet, assume the optimal storage type of + # the first opinionated user. This search is recursive. + for user in node.users: + optimal_storage = self.propose_node_layout(user, storage) + if optimal_storage is not None: + return optimal_storage + + # As a last resort, return the default storage type that should be used. + if self.default_layout in valid_layouts: + return self.default_layout + else: + return next(iter(valid_layouts)) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes)) + + for node in sorted_nodes: + if not isinstance(node.meta["val"], FakeTensor): + continue + + if node.target == exir_ops.edge.et_vk.prepack.default: + continue + + storage = self.propose_node_storage(node) + layout = self.propose_node_layout(node, storage) + + set_memory_metadata(node, storage, layout) + + inserting_transitions_for_node = False + for i, arg in enumerate(node.args): + if not isinstance(arg, torch.fx.Node): + continue + if not isinstance(arg.meta["val"], FakeTensor): + continue + + arg_storage = utils.get_node_storage_type(arg) + arg_layout = utils.get_node_memory_layout(arg) + + if arg_storage is None: + utils.set_node_spec_attr(arg, "vk_storage_type", storage) + arg_storage = storage + if arg_layout is None: + utils.set_node_spec_attr(arg, "vk_memory_layout", layout) + arg_layout = layout + + if arg_storage == storage and arg_layout == layout: + continue + + if not inserting_transitions_for_node: + inserting_transitions_for_node = True + logger.info( + f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:" + ) + + logger.info( + f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})" + ) + + # Insert a clone node to copy the original tensor to a tensor with the + # desired storage type and memory layout. + with graph_module.graph.inserting_before(node): + clone_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten.clone.default, + (arg,), + ) + clone_node.meta["val"] = arg.meta["val"] + clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) + clone_node.meta["spec"].const = False + set_memory_metadata(clone_node, storage, layout) + arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) + + return PassResult(graph_module, True) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index c851eeb4da..f1fd47fb2b 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -94,9 +94,11 @@ def op_node_is_compatible( # If there are no valid texture memory layouts, then buffer storage must be # supported by the operator implementation. if len(valid_texture_layouts) == 0: - # TODO: once memory metadata tagging pass is implemented, check that the - # op impl supports buffers instead - return False, "requires buffer representation" + compatible = VkStorageType.BUFFER in features.supported_storage_types() + reason = "op is compatible" + if not compatible: + reason = "op requires buffers which is not supported by op impl" + return compatible, reason op_available_layouts = features.supported_memory_layouts( VkStorageType.TEXTURE_3D diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index bc77bc40cf..8144747212 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -12,6 +12,11 @@ import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema import torch + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from executorch.backends.vulkan.utils import ( is_constant, is_get_attr_node, @@ -169,6 +174,15 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: if spec.mem_obj_id is not None: mem_obj_id = spec.mem_obj_id + storage_type = VkStorageType.DEFAULT_STORAGE + memory_layout = VkMemoryLayout.DEFAULT_LAYOUT + if hasattr(spec, "vk_storage_type"): + # pyre-ignore[16] + storage_type = spec.vk_storage_type + if hasattr(spec, "vk_memory_layout"): + # pyre-ignore[16] + memory_layout = spec.vk_memory_layout + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( @@ -177,6 +191,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: dims=spec.shape, constant_id=constant_id, mem_obj_id=mem_obj_id, + storage_type=storage_type, + memory_layout=memory_layout, ) ) ) diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 8197f705b5..35113bc623 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -37,6 +37,9 @@ class VkStorageType(IntEnum): TEXTURE_2D = 2 DEFAULT_STORAGE = 255 + def __str__(self) -> str: + return self.name + class VkMemoryLayout(IntEnum): TENSOR_WIDTH_PACKED = 0 @@ -44,6 +47,9 @@ class VkMemoryLayout(IntEnum): TENSOR_CHANNELS_PACKED = 2 DEFAULT_LAYOUT = 255 + def __str__(self) -> str: + return self.name + @dataclass class VkTensor: diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 9785b34951..9521bcacdb 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -223,6 +223,8 @@ def define_common_targets(is_fbcode = False): ], deps = [ "//caffe2:torch", + "//executorch/exir:tensor", + "//executorch/backends/vulkan/serialization:lib", ] ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 4264e94271..2e9fbba01c 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from enum import IntEnum -from typing import Set, Tuple +from typing import Optional, Set, Tuple import torch @@ -13,6 +13,9 @@ VkMemoryLayout, VkStorageType, ) + +from executorch.exir.tensor import TensorSpec + from torch._export.utils import is_buffer, is_param from torch._subclasses.fake_tensor import FakeTensor @@ -170,3 +173,43 @@ def possible_node_memory_layouts( ) return valid_layouts + + +## +## TensorSpec Utils +## + + +def set_node_spec_attr(node: torch.fx.Node, attr: str, value): + assert "spec" in node.meta + spec = node.meta["spec"] + if isinstance(spec, TensorSpec): + setattr(spec, attr, value) + elif isinstance(spec, list) or isinstance(spec, tuple): + for s in spec: + assert isinstance(s, TensorSpec) + setattr(s, attr, value) + else: + raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}") + + +def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True): + assert "spec" in node.meta + spec = node.meta["spec"] + if isinstance(spec, TensorSpec): + return getattr(spec, attr) if hasattr(spec, attr) else None + elif isinstance(spec, list) or isinstance(spec, tuple): + if return_first: + return getattr(spec[0], attr) if hasattr(spec, attr) else None + else: + return [getattr(s, attr) if hasattr(s, attr) else None for s in spec] + else: + raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}") + + +def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]: + return get_node_spec_attr(node, "vk_storage_type") + + +def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]: + return get_node_spec_attr(node, "vk_memory_layout") diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 96eee198f4..f0a5fd6725 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -6,7 +6,9 @@ # pyre-strict -from typing import final, List +from typing import Any, Dict, final, List + +import executorch.backends.vulkan.utils as utils from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform from executorch.backends.transforms.fuse_batch_norm_with_conv import ( @@ -20,9 +22,14 @@ from executorch.backends.vulkan._passes import ( insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, + TagMemoryMetaPass, ) from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( serialize_vulkan_graph, ) @@ -78,6 +85,24 @@ def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: return program +def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: + options = {} + for spec in compile_specs: + if spec.key == "storage_type_override": + options[spec.key] = VkStorageType( + int.from_bytes(spec.value, byteorder="little") + ) + if spec.key == "memory_layout_override": + options[spec.key] = VkMemoryLayout( + int.from_bytes(spec.value, byteorder="little") + ) + if spec.key in {"texture_limits_x", "texture_limits_y", "texture_limits_z"}: + options[spec.key] = int.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored + + return options + + @final class VulkanBackend(BackendDetails): @classmethod @@ -87,6 +112,25 @@ def preprocess( # noqa: C901 program: ExportedProgram, module_compile_spec: List[CompileSpec], ) -> PreprocessResult: + compile_options = parse_compile_spec(module_compile_spec) + limits_x = compile_options.get( + "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0] + ) + limits_y = compile_options.get( + "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1] + ) + limits_z = compile_options.get( + "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2] + ) + texture_limits = (limits_x, limits_y, limits_z) + + default_storage_type = compile_options.get( + "storage_type_override", VkStorageType.TEXTURE_3D + ) + default_memory_layout = compile_options.get( + "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED + ) + program = unsafe_remove_auto_functionalized_pass(program) # First, apply passes that fuse/remove operators to consolidate the graph @@ -122,10 +166,31 @@ def preprocess( # noqa: C901 ], ) + # Optionally apply the memory metadata tagging pass, which will insert storage + # type and memory layout transition nodes to ensure that all tensor arguments + # to an operator is in a supported or optimal configuration. If this pass is not + # applied, there will be a risk that some operators recieve arguments with + # memory settings that are not supported by the implementation. + if not compile_options.get("skip_tag_memory_metadata", False): + program = apply_passes( + program, + [ + TagMemoryMetaPass( + texture_limits, + default_storage_type=default_storage_type, + default_memory_layout=default_memory_layout, + ), + ], + ) + # Finally, apply dynamic shape passes and memory planning pass. These passes # must be applied only when the graph structure is finalized. program = apply_passes( - program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()] + program, + [ + ConstraintBasedSymShapeEvalPass(), + MemoryPlanningPass(), + ], ) graph_builder = VkGraphBuilder( diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index f3822b6866..23b3589c2a 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -622,7 +622,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 partitioners.append( get_vulkan_partitioner( args.dtype_override, - args.quantization_mode, + args.enable_dynamic_shape, ) ) modelname = f"vulkan_{modelname}" diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index d966de9a25..6f4b95e3d0 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -32,7 +32,7 @@ def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True): def get_vulkan_partitioner( - dtype_override: Optional[str] = None, quantization_mode: Optional[str] = None + dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False ): assert ( dtype_override == "fp32" or dtype_override is None @@ -41,7 +41,7 @@ def get_vulkan_partitioner( VulkanPartitioner, ) - return VulkanPartitioner({"require_dynamic_shapes": True}) + return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape}) def get_mps_partitioner(use_kv_cache: bool = False): From c5b88cc21508339034341657b17f37ba621692a7 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Tue, 5 Nov 2024 13:19:56 -0800 Subject: [PATCH 06/59] Add in-memory log buffer in Android JNI Differential Revision: D65474006 Pull Request resolved: https://github.com/pytorch/executorch/pull/6656 --- extension/android/CMakeLists.txt | 2 +- extension/android/jni/jni_layer.cpp | 81 ++++++++++++++++++- .../java/org/pytorch/executorch/Module.java | 5 ++ .../org/pytorch/executorch/NativePeer.java | 4 + 4 files changed, 88 insertions(+), 4 deletions(-) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 31f24b3979..c96cfeb5d7 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -190,4 +190,4 @@ target_include_directories( target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) -target_link_libraries(executorch_jni ${link_libraries}) +target_link_libraries(executorch_jni ${link_libraries} log) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index a6f0045725..479da28806 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -33,8 +33,45 @@ #include #include +using namespace executorch::extension; +using namespace torch::executor; + #ifdef __ANDROID__ #include +#include +#include + +// Number of entries to store in the in-memory log buffer. +const size_t log_buffer_length = 16; + +struct log_entry { + et_timestamp_t timestamp; + et_pal_log_level_t level; + std::string filename; + std::string function; + size_t line; + std::string message; + + log_entry( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) + : timestamp(timestamp), + level(level), + filename(filename), + function(function), + line(line), + message(message, length) {} +}; + +namespace { +std::vector log_buffer_; +std::mutex log_buffer_mutex_; +} // namespace // For Android, write to logcat void et_pal_emit_log_message( @@ -45,6 +82,15 @@ void et_pal_emit_log_message( size_t line, const char* message, size_t length) { + std::lock_guard guard(log_buffer_mutex_); + + while (log_buffer_.size() >= log_buffer_length) { + log_buffer_.erase(log_buffer_.begin()); + } + + log_buffer_.emplace_back( + timestamp, level, filename, function, line, message, length); + int android_log_level = ANDROID_LOG_UNKNOWN; if (level == 'D') { android_log_level = ANDROID_LOG_DEBUG; @@ -60,9 +106,6 @@ void et_pal_emit_log_message( } #endif -using namespace executorch::extension; -using namespace torch::executor; - namespace executorch::extension { class TensorHybrid : public facebook::jni::HybridClass { public: @@ -391,12 +434,44 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return jresult; } + facebook::jni::local_ref> + readLogBuffer() { +#ifdef __ANDROID__ + std::lock_guard guard(log_buffer_mutex_); + + const auto size = log_buffer_.size(); + facebook::jni::local_ref> ret = + facebook::jni::JArrayClass::newArray(size); + + for (auto i = 0u; i < size; i++) { + const auto& entry = log_buffer_[i]; + // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE". + std::stringstream ss; + ss << "[" << entry.timestamp << " " << entry.function << " " + << entry.filename << ":" << entry.line << "] " + << static_cast(entry.level) << " " << entry.message; + + facebook::jni::local_ref jstr_message = + facebook::jni::make_jstring(ss.str().c_str()); + (*ret)[i] = jstr_message; + } + + return ret; +#else + return facebook::jni::JArrayClass::newArray(0); +#endif + } + static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), makeNativeMethod("forward", ExecuTorchJni::forward), makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), + +#ifdef __ANDROID__ + makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer), +#endif }); } }; diff --git a/extension/android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/src/main/java/org/pytorch/executorch/Module.java index 608439548a..879b88c5f2 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/src/main/java/org/pytorch/executorch/Module.java @@ -99,6 +99,11 @@ public int loadMethod(String methodName) { return mNativePeer.loadMethod(methodName); } + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ + public String[] readLogBuffer() { + return mNativePeer.readLogBuffer(); + } + /** * Explicitly destroys the native torch::jit::Module. Calling this method is not required, as the * native object will be destroyed when this object is garbage-collected. However, the timing of diff --git a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java index 2cf2ee53d7..a5487a4702 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -54,4 +54,8 @@ public void resetNative() { */ @DoNotStrip public native int loadMethod(String methodName); + + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ + @DoNotStrip + public native String[] readLogBuffer(); } From 735e019f7315524ff239d12f8871d8bb390e198d Mon Sep 17 00:00:00 2001 From: David Lin Date: Tue, 5 Nov 2024 14:17:27 -0800 Subject: [PATCH 07/59] Fix broken apple tests Differential Revision: D65490319 Pull Request resolved: https://github.com/pytorch/executorch/pull/6664 --- extension/data_loader/file_data_loader.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/data_loader/file_data_loader.cpp b/extension/data_loader/file_data_loader.cpp index 0324751bfa..f5a3b94d84 100644 --- a/extension/data_loader/file_data_loader.cpp +++ b/extension/data_loader/file_data_loader.cpp @@ -76,7 +76,7 @@ FileDataLoader::~FileDataLoader() { ::close(fd_); } -Result getFDFromUri(const char* file_descriptor_uri) { +static Result getFDFromUri(const char* file_descriptor_uri) { // check if the uri starts with the prefix "fd://" ET_CHECK_OR_RETURN_ERROR( strncmp( From f7e26d749eacc5717aff39ca1622d0aa849ae13f Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 5 Nov 2024 14:30:29 -0800 Subject: [PATCH 08/59] [llama-mm] Add export-friendly tile position embedding (#6671) Summary: Before we make a decision on whether torchtune takes this export-friendly version of `TilePositionEmbedding`, we put it under `extension/llm` so that users can start to use it. Added unit tests to make sure the behavior is the same as the reference implementation in torchtune and export/AOTI/ET all working properly. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: fe65ec6c590b6579ac68847cd2e3d4a09921b7e5 Pull Request resolved: https://github.com/pytorch/executorch/pull/6650 Co-authored-by: Mengwei Liu --- extension/llm/modules/README.md | 14 + extension/llm/modules/__init__.py | 15 ++ extension/llm/modules/_position_embeddings.py | 243 ++++++++++++++++++ extension/llm/modules/test/__init__.py | 0 .../modules/test/test_position_embeddings.py | 118 +++++++++ pytest.ini | 1 + 6 files changed, 391 insertions(+) create mode 100644 extension/llm/modules/README.md create mode 100644 extension/llm/modules/__init__.py create mode 100644 extension/llm/modules/_position_embeddings.py create mode 100644 extension/llm/modules/test/__init__.py create mode 100644 extension/llm/modules/test/test_position_embeddings.py diff --git a/extension/llm/modules/README.md b/extension/llm/modules/README.md new file mode 100644 index 0000000000..3694f8b155 --- /dev/null +++ b/extension/llm/modules/README.md @@ -0,0 +1,14 @@ +## Export Friendly Modules + +Modules in this directory are: +* Extending `torch.nn.Module`. +* Guranteed to work out of the box with `torch.export.export()` and `torch.aot_compile()`. +* Guranteed to be able to work with ExecuTorch. + +All modules should be covered by unit tests to make sure they are: +1. giving the same output as the reference implementation in PyTorch or torchtune +2. export friendly +3. AOTI friendly +4. ExecuTorch friendly + +Notice that these modules are subject to change (may upstream to torchtune) so proceed with caution. diff --git a/extension/llm/modules/__init__.py b/extension/llm/modules/__init__.py new file mode 100644 index 0000000000..38245bf935 --- /dev/null +++ b/extension/llm/modules/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._position_embeddings import ( + replace_tile_positional_embedding, + TilePositionalEmbedding, +) + +__all__ = [ + "TilePositionalEmbedding", + "replace_tile_positional_embedding", +] diff --git a/extension/llm/modules/_position_embeddings.py b/extension/llm/modules/_position_embeddings.py new file mode 100644 index 0000000000..0c6a4f6ed9 --- /dev/null +++ b/extension/llm/modules/_position_embeddings.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# An torch.export() friendly version of torchtune's positional embeddings. +# Added torch._check() to make sure guards on symints are enforced. +# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py + +import logging +from typing import Any, Dict, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +class TilePositionalEmbedding(nn.Module): + """ + Positional embedding for tiles, different for every tile, same for every token within a tile. + + Notice that tile is different from patch (token). For details, please check the documentation of + :class:`torchtune.modules.vision_transformer.VisionTransformer`. + + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + embed_dim (int): The dimensionality of each tile embedding. + """ + + def __init__( + self, + max_num_tiles: int, + embed_dim: int, + ): + super().__init__() + self.max_num_tiles = max_num_tiles + self.embed_dim = embed_dim + + scale = embed_dim**-0.5 + self.embedding = nn.Parameter( + scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim) + ) + self.gate = nn.Parameter(torch.zeros(1)) + + # Register load hook to interpolate positional embeddings + self._register_load_state_dict_pre_hook(self._load_state_dict_hook) + + # TODO: Switch to public method after 2.5 is stable + @torch.no_grad() + def _load_state_dict_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + *args: Tuple[Any], + **kwargs: Dict[str, Any], + ): + """ + Interpolates positional embeddings to accomodate different number of tiles, + in case the model was instantiated with different + settings than the one you are loading the state dict from. + + For more info, check self._dynamic_resize function. + + Args: + state_dict (Dict[str, Any]): The state dict to load. + prefix (str): The prefix of the state dict. + *args (Tuple[Any]): Additional positional arguments. + **kwargs (Dict[str, Any]): Additional keyword arguments. + + Raises: + ValueError: if the shape of the loaded embedding is not compatible with the current embedding. + ValueError: if max_num_tiles_x, max_num_tiles_y are not equal. + ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. + """ + + embedding = state_dict.get(prefix + "embedding") + + if embedding is not None: + + # ckpt pos emb + ( + tgt_max_num_tiles_x, + tgt_max_num_tiles_y, + tgt_num_tokens, + tgt_emb, + ) = self.embedding.shape + + # instantiated pos emb + ( + inpt_max_num_tiles_x, + inpt_max_num_tiles_y, + inpt_num_tokens, + inpt_emb, + ) = state_dict[prefix + "embedding"].shape + + # sanity check + if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb: + raise ValueError( + "Expected embedding shape to be (..., num_tokens, tgt_emb) to match" + f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}" + ) + + if inpt_max_num_tiles_x != inpt_max_num_tiles_y: + raise ValueError( + "Expected max_num_tiles_x, max_num_tiles_y to be equal but found, but found" + f"(max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {self.embedding.shape}" + ) + + # resize ckpt to match instantiated shape + embedding_new = self._resize_position_embedding( + embedding, tgt_max_num_tiles=tgt_max_num_tiles_x + ) + + # update state dict + state_dict[prefix + "embedding"] = embedding_new + if embedding_new.shape != self.embedding.shape: + raise ValueError( + "Expected embedding shape and embedding_new.shape to match" + f" but found shapes {self.embedding.shape} and {embedding_new.shape}" + ) + + @staticmethod + def _resize_position_embedding( + embedding: torch.Tensor, tgt_max_num_tiles: int + ) -> torch.Tensor: + """ + Interpolates positional embeddings to accomodate a different max_num_tiles. These + are the only dimensions that changes during interpolation. + + Args: + embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim + tgt_max_num_tiles (int): The number of tiles to resize to. + + Returns: + torch.Tensor: The resized embedding. + + Example: + >>> import torch + >>> # create dummy embedding + >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float() + >>> resized_embed = _dynamic_resize(embedding, tgt_max_num_tiles=1) + >>> print(resized_embed.shape) + >>> torch.Size([1, 1, 2, 2]) + """ + # set max_num_tiles to the last dimension + embedding = embedding.permute(2, 3, 0, 1) + + embedding = F.interpolate( + embedding, + size=(tgt_max_num_tiles, tgt_max_num_tiles), + mode="bilinear", + align_corners=True, + ) + # permute to the original shape + embedding = embedding.permute(2, 3, 0, 1) + return embedding + + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: + """ + args: + x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), + representing the aspect ratio of the image before tile-cropping, e.g. (2,1). + returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + torch._check(n_tiles <= self.max_num_tiles) + + for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + n_tiles_h = n_tiles_h.item() + n_tiles_w = n_tiles_w.item() + + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. n_tiles_h, n_tiles_w. + torch._check_is_size(n_tiles_h) + torch._check_is_size(n_tiles_w) + torch._check(n_tiles_h >= 1) + torch._check(n_tiles_w >= 1) + torch._check(n_tiles_h <= self.max_num_tiles) + torch._check(n_tiles_w <= self.max_num_tiles) + # TODO: Remove this once pytorch/pytorch#120288 is fixed + padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1)) + pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] + + # We need to do a clone here in order to make this model export + # friendly as the reshape is collapsing dim 0 and dim 1 into a + # single dim. + pos_embed = pos_embed.clone() + pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim) + + x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) + torch._check_is_size(n_non_padded_tiles) + torch._check(n_non_padded_tiles < x.size(1)) + x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh() + x = x[:, :n_tiles, :, :] + + return x + + +def replace_tile_positional_embedding(model: nn.Module) -> nn.Module: + """ + Replace the tile positional embedding from torchtune with an export-friendly one. + Recursively searches the submodules of the model and replaces the tile positional embedding if found. + Args: + model (nn.Module): The model to replace the tile positional embedding in. + + Returns: + nn.Module: The model after replacing the tile positional embedding. + + """ + from torchtune.models.clip._position_embeddings import ( + TilePositionalEmbedding as TuneTilePositionalEmbedding, + ) + + for name, module in model.named_children(): + if isinstance(module, TuneTilePositionalEmbedding): + logging.info( + f"Replacing tile positional embedding in {name} with export-friendly one." + ) + max_num_tiles, _, _, embed_dim = module.embedding.shape + mod = TilePositionalEmbedding( + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + ) + mod.load_state_dict(module.state_dict()) + setattr( + model, + name, + mod, + ) + else: + replace_tile_positional_embedding(module) + return model diff --git a/extension/llm/modules/test/__init__.py b/extension/llm/modules/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/extension/llm/modules/test/test_position_embeddings.py b/extension/llm/modules/test/test_position_embeddings.py new file mode 100644 index 0000000000..cf4e7e7f05 --- /dev/null +++ b/extension/llm/modules/test/test_position_embeddings.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import tempfile +import unittest + +import torch +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.extension.llm.modules import ( + replace_tile_positional_embedding, + TilePositionalEmbedding, +) +from executorch.runtime import Runtime +from torch._inductor.package import load_package, package_aoti +from torchtune.models.clip import TilePositionalEmbedding as TuneTilePositionalEmbedding + + +class TilePositionalEmbeddingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tpe = TilePositionalEmbedding(4, 1280) + self.ref_tpe = TuneTilePositionalEmbedding(4, 1280) + self.x = torch.randn(1, 4, 1600, 1280) + self.aspect_ratio = torch.tensor([[1, 1]]) + num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4) + num_tokens = torch.export.Dim("num_tokens", min=1, max=1600) + + self.dynamic_shape = { + 0: 1, # batch + 1: num_tiles_dim, # num tiles + 2: num_tokens, # num tokens + 3: 1280, # embedding dim + } + + def test_tile_positional_embedding_smoke(self): + y = self.tpe(self.x, self.aspect_ratio) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y, ref_y)) + + def test_tile_positional_embedding_export(self): + + tpe_ep = torch.export.export( + self.tpe, + (self.x, self.aspect_ratio), + dynamic_shapes=( + self.dynamic_shape, + None, + ), # assuming aspect ratio is static + ) + + y = tpe_ep.module()(self.x, self.aspect_ratio) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y, ref_y)) + + def test_tile_positional_embedding_aoti(self): + so = torch._export.aot_compile( + self.tpe, + args=(self.x, self.aspect_ratio), + options={"aot_inductor.package": True}, + dynamic_shapes=( + self.dynamic_shape, + None, + ), # assuming aspect ratio is static + ) + with tempfile.TemporaryDirectory() as tmpdir: + path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so) + tpe_aoti = load_package(path) + + y = tpe_aoti(self.x, self.aspect_ratio) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y, ref_y)) + + def test_tile_positional_embedding_et(self): + tpe_ep = torch.export.export( + self.tpe, + (self.x, self.aspect_ratio), + dynamic_shapes=( + self.dynamic_shape, + None, + ), # assuming aspect ratio is static + ) + et_program = to_edge( + tpe_ep, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[ + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_scalar.default, + torch.ops.aten._local_scalar_dense.default, + ] + ), + ).to_executorch() + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + y = method.execute((self.x, self.aspect_ratio)) + ref_y = self.ref_tpe(self.x, self.aspect_ratio) + + self.assertTrue(torch.allclose(y[0], ref_y)) + + def test_replace_tile_positional_embedding(self): + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.tpe = TuneTilePositionalEmbedding(4, 1280) + + def forward(self, x, aspect_ratio): + return self.tpe(x, aspect_ratio) + + m = Module() + m = replace_tile_positional_embedding(m) + self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding)) diff --git a/pytest.ini b/pytest.ini index 3666c9c879..a5041504ae 100644 --- a/pytest.ini +++ b/pytest.ini @@ -38,6 +38,7 @@ addopts = # backends/xnnpack backends/xnnpack/test # extension/ + extension/llm/modules/test extension/pybindings/test # Runtime runtime From 068f43c141013103a406ac0dcaf48f4efcaacfc4 Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Wed, 6 Nov 2024 10:00:09 +0800 Subject: [PATCH 09/59] Qualcomm AI Engine Direct - Quantizer refine for qat (#6513) * [Qualcomm AI Engine Direct - Quantizer refine for qat] - Reorginize qualcomm/quantizer - Split quantizer/utils.py to -- qconfig -- annotators -- observers directory - Change coresponding callees - Rename get_default_Nbit_qnn_ptq_config to get_NaNw_qnn_ptq_config - Add 16a4w conv test* (It is not compared with original model) * Fix baed on comments - Move and rename param_observer.py to per_channel_param_observer.py - Add todo to merge qconfig * Add a comment - Add todo for per_channel_param_observer.py * [Fix lint] --------- Co-authored-by: Joey Tsai --- .../quantizer/{utils.py => annotators.py} | 447 ++--------------- .../qualcomm/quantizer/custom_annotation.py | 6 +- .../observers/per_channel_param_observer.py | 104 ++++ backends/qualcomm/quantizer/qconfig.py | 464 ++++++++++++++++++ backends/qualcomm/quantizer/quantizer.py | 140 +++--- backends/qualcomm/tests/test_qnn_delegate.py | 25 +- backends/qualcomm/tests/utils.py | 38 +- backends/qualcomm/utils/utils.py | 2 +- examples/qualcomm/oss_scripts/fastvit.py | 18 +- examples/qualcomm/oss_scripts/llama2/llama.py | 8 +- examples/qualcomm/scripts/export_example.py | 7 +- examples/qualcomm/utils.py | 96 ++-- extension/llm/export/quantizer_lib.py | 19 +- 13 files changed, 790 insertions(+), 584 deletions(-) rename backends/qualcomm/quantizer/{utils.py => annotators.py} (68%) create mode 100644 backends/qualcomm/quantizer/observers/per_channel_param_observer.py create mode 100644 backends/qualcomm/quantizer/qconfig.py diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/annotators.py similarity index 68% rename from backends/qualcomm/quantizer/utils.py rename to backends/qualcomm/quantizer/annotators.py index dc3d2a6841..275da567e8 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -5,29 +5,16 @@ # LICENSE file in the root directory of this source tree. import numbers import operator -from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Sequence, Tuple import torch - -from torch import Tensor from torch._ops import OpOverload -from torch._subclasses import FakeTensor - -from torch.ao.quantization.fake_quantize import ( - default_fake_quant, - FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( - FixedQParamsObserver, - MinMaxObserver, - MovingAverageMinMaxObserver, - PerChannelMinMaxObserver, - UniformQuantizationObserverBase, -) +from torch._subclasses import FakeTensor +from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.ao.quantization.observer import FixedQParamsObserver from torch.ao.quantization.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, @@ -40,397 +27,12 @@ ) from torch.fx import Node - -class ParamObserver(UniformQuantizationObserverBase): - def __init__( - self, - ch_axis=0, - use_mse=True, - steps=100, - dtype=torch.int8, - qscheme=torch.per_channel_symmetric, - reduce_range=False, - quant_min=None, - quant_max=None, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps, # noqa: B008 - is_dynamic=False, - **kwargs, - ) -> None: - super().__init__( - dtype=dtype, - qscheme=qscheme, - reduce_range=reduce_range, - quant_min=quant_min, - quant_max=quant_max, - factory_kwargs=factory_kwargs, - eps=eps, - is_dynamic=is_dynamic, - **kwargs, - ) - - factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) - self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) - self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) - self.ch_axis = ch_axis - self.use_mse = use_mse - self.steps = steps - self.calibrated = False - - def to_ch_axis(self, x): - axis_order = list(range(len(x.size()))) - axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis - return torch.flatten(x.permute(axis_order), start_dim=1) - - def mse(self, pred, expect): - loss = (pred - expect).abs().pow(2) - return self.to_ch_axis(loss).mean(1) - - def cosine(self, pred, expect): - target = torch.ones(pred.shape[self.ch_axis]) - pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) - expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) - return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) - - def loss_fn(self, x, new_min, new_max): - scale, offset = self._calculate_qparams(new_min, new_max) - x_q = torch.fake_quantize_per_channel_affine( - x, - scale.data, - offset.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) - - def line_search(self, x): - x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) - x_range = torch.max(x_min.abs(), x_max) - optimal_loss = torch.zeros_like(x_min) + 1e9 - - # check which clip range could produce smallest loss - for i in range(1, self.steps + 1): - thres = x_range / self.steps * i - current_loss = self.loss_fn(x, -thres, thres) - x_min = torch.where(current_loss < optimal_loss, -thres, x_min) - x_max = torch.where(current_loss < optimal_loss, thres, x_max) - optimal_loss = torch.min(current_loss, optimal_loss) - - return x_min, x_max - - def forward(self, x_orig): - # since params are static, one calibration is enough - if not self.calibrated: - x = x_orig.detach().to(self.min_val.dtype) - self.min_val, self.max_val = self.line_search(x) - self.calibrated = True - - # return fake-quant result for saturating outliers - scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) - return torch.fake_quantize_per_channel_affine( - x_orig, - scale.data, - zero_point.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - - @torch.jit.export - def calculate_qparams(self): - return self._calculate_qparams(self.min_val, self.max_val) - - -@dataclass(eq=True, frozen=True) -class QuantizationConfig: - input_activation: Optional[QuantizationSpec] - output_activation: Optional[QuantizationSpec] - weight: Optional[QuantizationSpec] - bias: Optional[QuantizationSpec | Callable] - - -def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: - def _derive_bias_qparams_fn( - obs_or_fqs: List, - ) -> Tuple[Tensor, Tensor]: - assert ( - len(obs_or_fqs) == 2 - ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( - act_scale, weight_scale - ) - derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - return (derived_scale, derived_zero) - - input_act = node.args[0] - assert isinstance(input_act, Node) - weight = node.args[1] - assert isinstance(weight, Node) - - return DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - ch_axis=0, - qscheme=torch.per_channel_symmetric, - ) - - -def get_default_8bit_qat_proto(act_symmetric: bool = False) -> QuantizationConfig: - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=default_fake_quant, - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver - ), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=default_fake_quant, - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_default_8bit_qnn_ptq_config( - act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a8w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_default_16bit_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min + 1, - quant_max=torch.iinfo(torch.int16).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - # torch does not support uint16 quantization, use int32 to bypass - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, weight_dtype=torch.int8 -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config +from .qconfig import ( + get_16a16w_qnn_ptq_config, + get_16a4w_qnn_qat_config, + get_8a8w_qnn_qat_config, + QuantizationConfig, +) QUANT_ANNOTATION_KEY = "quantization_annotation" @@ -901,19 +503,34 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non scale = 1 / (q_max - q_min + 1) - # make sigmoid map to the range between 0~1 - out_act_quantization_spec = QuantizationSpec( + bias_obs_ctr = observer = FixedQParamsObserver.with_args( + scale=scale, + zero_point=0, dtype=quantization_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - observer_or_fake_quant_ctr=FixedQParamsObserver.with_args( + ) + if quantization_config in ( + get_8a8w_qnn_qat_config(), + get_16a4w_qnn_qat_config(), + ): + bias_obs_ctr = FixedQParamsFakeQuantize.with_args( + observer=observer, scale=scale, zero_point=0, dtype=quantization_config.output_activation.dtype, qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - ), + ) + + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + quant_max=q_max, + quant_min=q_min, + observer_or_fake_quant_ctr=bias_obs_ctr, qscheme=torch.torch.per_tensor_affine, ) @@ -1086,7 +703,7 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -1115,7 +732,7 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -1258,7 +875,7 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> _annotate_input_qspec_map( node, weight_node, - get_default_16bit_qnn_ptq_config().weight, + get_16a16w_qnn_ptq_config().weight, ) else: _annotate_input_qspec_map( diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index db82172a9e..9d6dea8a97 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -6,12 +6,12 @@ from typing import Sequence import torch +from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_default_8bit_qnn_ptq_config, + get_8a8w_qnn_ptq_config, QuantizationConfig, ) -from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.quantizer import ( QuantizationAnnotation, @@ -110,7 +110,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): # Annotate 16a8w for matmul op to get better performance quantization_config_16a8w = get_16a8w_qnn_ptq_config() # Annotate 8a8w for second input of matmul until past_kv_cache - quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: if "nn_module_stack" in node.meta: diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py new file mode 100644 index 0000000000..d556dfa4ba --- /dev/null +++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py @@ -0,0 +1,104 @@ +import torch +from torch.ao.quantization.observer import UniformQuantizationObserverBase + + +# TODO move to torch/ao/quantization/observer.py. +class PerChannelParamObserver(UniformQuantizationObserverBase): + def __init__( + self, + ch_axis=0, + use_mse=True, + steps=100, + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.ch_axis = ch_axis + self.use_mse = use_mse + self.steps = steps + self.calibrated = False + + def to_ch_axis(self, x): + axis_order = list(range(len(x.size()))) + axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis + return torch.flatten(x.permute(axis_order), start_dim=1) + + def mse(self, pred, expect): + loss = (pred - expect).abs().pow(2) + return self.to_ch_axis(loss).mean(1) + + def cosine(self, pred, expect): + target = torch.ones(pred.shape[self.ch_axis]) + pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) + expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) + return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) + + def loss_fn(self, x, new_min, new_max): + scale, offset = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_channel_affine( + x, + scale.data, + offset.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) + + def line_search(self, x): + x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) + x_range = torch.max(x_min.abs(), x_max) + optimal_loss = torch.zeros_like(x_min) + 1e9 + + # check which clip range could produce smallest loss + for i in range(1, self.steps + 1): + thres = x_range / self.steps * i + current_loss = self.loss_fn(x, -thres, thres) + x_min = torch.where(current_loss < optimal_loss, -thres, x_min) + x_max = torch.where(current_loss < optimal_loss, thres, x_max) + optimal_loss = torch.min(current_loss, optimal_loss) + + return x_min, x_max + + def forward(self, x_orig): + # since params are static, one calibration is enough + if not self.calibrated: + x = x_orig.detach().to(self.min_val.dtype) + self.min_val, self.max_val = self.line_search(x) + self.calibrated = True + + # return fake-quant result for saturating outliers + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + return torch.fake_quantize_per_channel_affine( + x_orig, + scale.data, + zero_point.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + + @torch.jit.export + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py new file mode 100644 index 0000000000..e07ca24d90 --- /dev/null +++ b/backends/qualcomm/quantizer/qconfig.py @@ -0,0 +1,464 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, +) +from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec +from torch.fx import Node + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec | Callable] + + +def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: + def _derive_bias_qparams_fn( + obs_or_fqs: List, + ) -> Tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( + act_scale, weight_scale + ) + derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) + derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + return (derived_scale, derived_zero) + + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + return DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + ch_axis=0, + qscheme=torch.per_channel_symmetric, + ) + + +def get_8a8w_qnn_ptq_config( + act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# 4 bits quantization only supports specific ops. +def get_16a4w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a8w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a16w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # torch does not support uint16 quantization, use int32 to bypass + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, + weight_dtype=torch.int8, + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# TODO merge qat and ptq to a fucntion, and use a bool flag to control it +def get_8a8w_qnn_qat_config( + act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + act_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + reduce_range=True, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=bias_fake_quant_ctr, + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a4w_qnn_qat_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + act_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + reduce_range=True, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + reduce_range=True, + observer=MovingAverageMinMaxObserver, + ) + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=bias_fake_quant_ctr, + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_qat_per_channel_quant_config( + act_dtype=torch.uint8, + weight_dtype=torch.int8, + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_fake_quant_ctr = FakeQuantize.with_args( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + reduce_range=True, + observer=act_observer, + ) + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_fake_quant_ctr, + ) + + weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer=MovingAveragePerChannelMinMaxObserver, + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=weight_fake_quant_ctr, + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 9e5aaf782a..50ed07788f 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from enum import IntEnum, unique -from typing import Callable, Dict, Optional, Sequence, Set +from typing import Callable, Optional, Sequence, Set import torch from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum @@ -22,14 +22,17 @@ from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule -from .utils import ( +from .annotators import OP_ANNOTATOR + +from .qconfig import ( + get_16a16w_qnn_ptq_config, get_16a4w_qnn_ptq_config, + get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, - get_default_16bit_qnn_ptq_config, - get_default_8bit_qat_proto, - get_default_8bit_qnn_ptq_config, + get_8a8w_qnn_ptq_config, + get_8a8w_qnn_qat_config, get_ptq_per_channel_quant_config, - OP_ANNOTATOR, + get_qat_per_channel_quant_config, QuantizationConfig, ) @@ -38,9 +41,10 @@ "QuantDtype", "get_16a4w_qnn_ptq_config", "get_16a8w_qnn_ptq_config", - "get_default_16bit_qnn_ptq_config", - "get_default_8bit_qnn_ptq_config", - "get_default_8bit_qat_proto", + "get_16a16w_qnn_ptq_config", + "get_8a8w_qnn_ptq_config", + "get_8a8w_qnn_qat_config", + "get_16a4w_qnn_qat_config", ] @@ -51,8 +55,39 @@ class QuantDtype(IntEnum): """ use_16a16w = 0 - use_16a4w = 1 - use_8a8w = 2 + use_16a8w = 1 + use_16a4w = 2 + use_8a8w = 3 + + +quant_config_dict = { + # PTQ + (QuantDtype.use_16a16w, False): ( + get_16a16w_qnn_ptq_config, + get_ptq_per_channel_quant_config(torch.uint16, torch.int16), + ), + (QuantDtype.use_16a8w, False): ( + get_16a8w_qnn_ptq_config, + get_ptq_per_channel_quant_config(torch.uint16, torch.int8), + ), + (QuantDtype.use_16a4w, False): ( + get_16a4w_qnn_ptq_config, + get_ptq_per_channel_quant_config(torch.uint16, "int4"), + ), + (QuantDtype.use_8a8w, False): ( + get_8a8w_qnn_ptq_config, + get_ptq_per_channel_quant_config(), + ), + # QAT, + (QuantDtype.use_16a4w, True): ( + get_16a4w_qnn_qat_config, + get_qat_per_channel_quant_config(torch.uint16, "int4"), + ), + (QuantDtype.use_8a8w, True): ( + get_8a8w_qnn_qat_config, + get_qat_per_channel_quant_config(), + ), +} class QnnQuantizer(Quantizer): @@ -60,23 +95,17 @@ class QnnQuantizer(Quantizer): def __init__(self): super().__init__() - self.bit8_quant_config: QuantizationConfig = get_default_8bit_qnn_ptq_config() - self.bit16_quant_config: QuantizationConfig = get_default_16bit_qnn_ptq_config() + self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() - self.bit8_quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() - self.bit16_quant_ops: Set[OpOverload] = set() + self.is_qat = False + self.quant_dtype = QuantDtype.use_8a8w + self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config() + self.per_channel_quant_config = get_ptq_per_channel_quant_config() + self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() - self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() - # the weight quantized for activation 8 bits and 16 bits - self.per_channel_weight_dtype: Dict = { - "8bit_act": torch.int8, - "16bit_act": torch.int16, - } - self.per_channel_quant_config = None - def _annotate(self, gm: GraphModule) -> None: for node in gm.graph.nodes: if node.name in self.discard_nodes: @@ -94,29 +123,16 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig """ Priority: 1. is one of use_per_channel_weight_quant_ops - 2. int8 / int16 config + 2. quant config """ if isinstance(op, str): return if op in self.use_per_channel_weight_quant_ops: - if self.per_channel_quant_config is None: - if op in self.bit16_quant_ops: - return get_ptq_per_channel_quant_config( - act_dtype=torch.uint16, - weight_dtype=self.per_channel_weight_dtype["16bit_act"], - ) - return get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=self.per_channel_weight_dtype["8bit_act"], - ) return self.per_channel_quant_config - if op in self.bit8_quant_ops: - return self.bit8_quant_config - - if op in self.bit16_quant_ops: - return self.bit16_quant_config + if op in self.quant_ops: + return self.quant_config print(f"No quant config is implemented for op, {op}") @@ -126,15 +142,6 @@ def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: boo else: self.use_per_channel_weight_quant_ops.difference_update(ops) - def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None: - for op in ops: - assert ( - op in self.SUPPORTED_OPS - ), f"The annotation of op {op} is not implemented" - - self.bit8_quant_ops.remove(op) - self.bit16_quant_ops.add(op) - def add_custom_quant_annotations( self, custom_quant_annotations: Sequence[Callable] ) -> None: @@ -145,10 +152,7 @@ def add_discard_nodes(self, nodes: Sequence[str]) -> None: def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: for op in ops: - if op in self.bit8_quant_ops: - self.bit8_quant_ops.remove(op) - if op in self.bit16_quant_ops: - self.bit16_quant_ops.remove(op) + self.quant_ops.remove(op) def annotate(self, model: GraphModule) -> GraphModule: self._annotate(model) @@ -159,24 +163,22 @@ def annotate(self, model: GraphModule) -> GraphModule: def get_supported_ops(self) -> Set[OpOverload]: return self.SUPPORTED_OPS - def set_bit16_op_quant_config( - self, quantization_config: QuantizationConfig - ) -> None: - self.bit16_quant_config = quantization_config - - def set_bit8_op_quant_config(self, quantization_config: QuantizationConfig) -> None: - self.bit8_quant_config = quantization_config - - def set_per_channel_weight_dtype( - self, - weight_dtype_for_8bit_act: Optional[str | torch.dtype] = None, - weight_dtype_for_16bit_act: Optional[str | torch.dtype] = None, + def set_quant_config( + self, quant_dtype: QuantDtype, is_qat=False, act_observer=None ) -> None: - # TODO accept temporally str type. Remove it when torch support torch.int4 dtype - if weight_dtype_for_8bit_act: - self.per_channel_weight_dtype["8bit_act"] = weight_dtype_for_8bit_act - if weight_dtype_for_16bit_act: - self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act + self.quant_dtype = quant_dtype + self.is_qat = is_qat + if (quant_dtype, is_qat) not in quant_config_dict: + raise RuntimeError( + f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support" + ) + + quant_config_fuc, self.per_channel_quant_config = quant_config_dict[ + (quant_dtype, is_qat) + ] + self.quant_config = ( + quant_config_fuc(act_observer) if act_observer else quant_config_fuc() + ) def set_per_channel_conv_quant(self, enable: bool) -> None: conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default} diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4bfdedcd4b..64b0490d46 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -698,6 +698,17 @@ def test_qnn_backend_16a4w_conv2d(self): ) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_16a4w_conv2d_qat(self): + modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405 + sample_input = (torch.randn([1, 1, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + prepared = self.get_prepared_qat_module(module, sample_input) + converted = self.get_converted_sgd_trained_module( + module, prepared, sample_input + ) + self.lower_module_and_test_output(converted, sample_input) + def test_qnn_backend_16a4w_layer_norm(self): module = LayerNorm() # noqa: F405 sample_input = (torch.randn(196, 768),) @@ -1063,18 +1074,8 @@ def test_qnn_backend_linear_qat(self): """ module = Linear() # noqa: F405 sample_input = (torch.randn([3, 4]),) - - module = self.get_prepared_qat_module(module, sample_input) - - optimizer = torch.optim.SGD(module.parameters(), lr=0.1) - criterion = torch.nn.CrossEntropyLoss() - output = module(*sample_input) - loss = criterion(output, module(*sample_input)) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - module = torch.ao.quantization.quantize_pt2e.convert_pt2e(module) + prepared = self.get_prepared_qat_module(module, sample_input) + module = self.get_converted_sgd_trained_module(module, prepared, sample_input) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_log_softmax(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 114493c7d2..d2a3e7c241 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -17,13 +17,7 @@ from executorch import exir from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.qnn_preprocess import QnnBackend -from executorch.backends.qualcomm.quantizer.quantizer import ( - get_16a4w_qnn_ptq_config, - get_default_16bit_qnn_ptq_config, - get_default_8bit_qat_proto, - QnnQuantizer, - QuantDtype, -) +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -405,18 +399,7 @@ def get_qdq_module( quantizer.add_custom_quant_annotations(custom_quant_annotations) quantizer.set_per_channel_conv_quant(is_conv_per_channel) quantizer.set_per_channel_linear_quant(is_linear_per_channel) - - if quant_dtype == QuantDtype.use_8a8w: - pass # default setting - elif quant_dtype == QuantDtype.use_16a16w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) - elif quant_dtype == QuantDtype.use_16a4w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) - quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") - else: - raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + quantizer.set_quant_config(quant_dtype) prepared = prepare_pt2e(m, quantizer) prepared(*inputs) @@ -448,13 +431,28 @@ def get_prepared_qat_module( quantizer.set_per_channel_linear_quant(is_linear_per_channel) if quant_dtype == QuantDtype.use_8a8w: - quantizer.set_bit8_op_quant_config(get_default_8bit_qat_proto()) + quantizer.set_quant_config(quant_dtype, is_qat=True) else: raise RuntimeError("Shuld not be here") prepared = prepare_qat_pt2e(m, quantizer) return torch.ao.quantization.move_exported_model_to_train(prepared) + def get_converted_sgd_trained_module( + self, + ori_module: torch.nn.Module, + prepared: torch.nn.Module, + inputs: Tuple[torch.Tensor], + ) -> torch.fx.GraphModule: + optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) + criterion = torch.nn.CrossEntropyLoss() + output = prepared(*inputs) + loss = criterion(output, ori_module(*inputs)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) + def split_graph(self, graph_module: torch.fx.GraphModule, division: int): class SplitGraph(ExportPass): """ diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 0ea4512abc..cb54412add 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -331,7 +331,7 @@ def _transform( def capture_program( module: torch.nn.Module, inputs: Tuple[torch.Tensor], - custom_pass_config: Set[str] = None, + custom_pass_config: Set[str] = frozenset(), ) -> exir.ExirExportedProgram: ep = torch.export.export(module, inputs) decomposed_ep = ep.run_decompositions(get_decomp_table()) diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index 30fe74f35b..0e2c695ab3 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -10,15 +10,19 @@ import numpy as np import torch - -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.backends.qualcomm.quantizer.utils import ( - _derived_bias_quant_spec, - MovingAverageMinMaxObserver, - ParamObserver, +from executorch.backends.qualcomm.quantizer.annotators import ( QuantizationConfig, QuantizationSpec, ) +from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( + PerChannelParamObserver, +) +from executorch.backends.qualcomm.quantizer.qconfig import ( + _derived_bias_quant_spec, + MovingAverageMinMaxObserver, +) + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_EXPAND_BROADCAST_SHAPE, ) @@ -87,7 +91,7 @@ def main(args): quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_channel_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=ParamObserver.with_args( + observer_or_fake_quant_ctr=PerChannelParamObserver.with_args( **{"steps": 200, "use_mse": True} ), ) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index 04569df5c9..9f7198a344 100644 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -56,12 +56,12 @@ def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: This function is specific for matmul op 16a8w. """ + from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_default_8bit_qnn_ptq_config, + get_8a8w_qnn_ptq_config, QuantizationConfig, ) - from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import ( QuantizationAnnotation, SharedQuantizationSpec, @@ -119,7 +119,7 @@ def annotate_single_in_single_out( ) def annotate_matmul_input1(node: Node): - quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) while isinstance(node, Node) and node.op == "call_function": if node.target in [ torch.ops.aten.permute.default, @@ -142,11 +142,11 @@ def annotate_matmul_input1(node: Node): def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: + from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_ptq_per_channel_quant_config, QuantizationConfig, ) - from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import Node diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index 2e49a2344b..56169e39a2 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -4,10 +4,7 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import ( - get_default_8bit_qnn_ptq_config, - QnnQuantizer, -) +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -64,8 +61,6 @@ def main() -> None: # Get quantizer quantizer = QnnQuantizer() - quant_config = get_default_8bit_qnn_ptq_config() - quantizer.set_bit8_op_quant_config(quant_config) # Typical pytorch 2.0 quantization flow m = torch.export.export(model.eval(), example_inputs).module() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 06225be2d1..100008e91c 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -16,13 +16,7 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import ( - get_16a4w_qnn_ptq_config, - get_default_16bit_qnn_ptq_config, - get_default_8bit_qnn_ptq_config, - QnnQuantizer, - QuantDtype, -) +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -37,7 +31,11 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torch.ao.quantization.observer import MovingAverageMinMaxObserver -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) class SimpleADB: @@ -187,36 +185,58 @@ def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None): callback() +def ptq_calibrate(captured_model, quantizer, dataset): + annotated_model = prepare_pt2e(captured_model, quantizer) + print("Quantizing(PTQ) the model...") + # calibration + if callable(dataset): + dataset(annotated_model) + else: + for data in dataset: + annotated_model(*data) + return annotated_model + + +def qat_train(ori_model, captured_model, quantizer, dataset): + data, targets = dataset + annotated_model = torch.ao.quantization.move_exported_model_to_train( + prepare_qat_pt2e(captured_model, quantizer) + ) + optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) + criterion = torch.nn.CrossEntropyLoss() + for i, d in enumerate(data): + print(f"Epoch {i}") + if i > 3: + # Freeze quantizer parameters + annotated_model.apply(torch.ao.quantization.disable_observer) + if i > 2: + # Freeze batch norm mean and variance estimates + annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) + + output = annotated_model(*d) + loss = criterion(output, targets[i]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return torch.ao.quantization.quantize_pt2e.convert_pt2e( + torch.ao.quantization.move_exported_model_to_eval(annotated_model) + ) + + def make_quantizer( - quant_dtype: Optional[QuantDtype], + quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w, custom_annotations=(), per_channel_conv=True, per_channel_linear=False, act_observer=MovingAverageMinMaxObserver, + is_qat=False, ): quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) quantizer.set_per_channel_conv_quant(per_channel_conv) quantizer.set_per_channel_linear_quant(per_channel_linear) - - if quant_dtype == QuantDtype.use_8a8w: - quantizer.set_bit8_op_quant_config( - get_default_8bit_qnn_ptq_config(act_observer=act_observer) - ) - elif quant_dtype == QuantDtype.use_16a16w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config( - get_default_16bit_qnn_ptq_config(act_observer=act_observer) - ) - elif quant_dtype == QuantDtype.use_16a4w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config( - get_16a4w_qnn_ptq_config(act_observer=act_observer) - ) - quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") - else: - raise AssertionError(f"No support for QuantDtype {quant_dtype}.") - + quantizer.set_quant_config(quant_dtype, is_qat, act_observer) return quantizer @@ -235,18 +255,22 @@ def build_executorch_binary( metadata=None, dump_intermediate_outputs=False, custom_pass_config=frozenset(), + qat_training_data=None, ): if quant_dtype is not None: - quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) captured_model = torch.export.export(model, inputs).module() - annotated_model = prepare_pt2e(captured_model, quantizer) - print("Quantizing the model...") - # calibration - if callable(dataset): - dataset(annotated_model) + if qat_training_data: + quantizer = custom_quantizer or make_quantizer( + quant_dtype=quant_dtype, is_qat=True + ) + # qat training + annotated_model = qat_train( + model, captured_model, quantizer, qat_training_data + ) else: - for data in dataset: - annotated_model(*data) + quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) + # ptq calibration + annotated_model = ptq_calibrate(captured_model, quantizer, dataset) quantized_model = convert_pt2e(annotated_model) edge_prog = capture_program(quantized_model, inputs, custom_pass_config) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index fd368d73f1..ba281864a9 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -144,6 +144,7 @@ def check_embedding_byte_registered(): def get_qnn_quantizer( pt2e_quantize: str, quantization_mode: Optional[str] = None, + is_qat: bool = False, ): try: from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] @@ -152,8 +153,6 @@ def get_qnn_quantizer( # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` from executorch.backends.qualcomm.quantizer.quantizer import ( - get_16a4w_qnn_ptq_config, - get_default_16bit_qnn_ptq_config, QnnQuantizer, QuantDtype, ) @@ -175,6 +174,7 @@ def get_qnn_quantizer( custom_annotations = () if quant_config == "8a8w": quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16] + qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat) elif quant_config == "16a16w": quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16] # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w @@ -184,20 +184,17 @@ def get_qnn_quantizer( ) qnn_quantizer.set_per_channel_conv_quant(enable=False) qnn_quantizer.set_per_channel_linear_quant(enable=False) - qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config( - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - get_default_16bit_qnn_ptq_config(act_observer=MinMaxObserver) + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + qnn_quantizer.set_quant_config( + quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver ) elif quant_config == "16a4w": # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. quant_dtype = QuantDtype.use_16a4w - qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) - qnn_quantizer.set_bit16_op_quant_config( - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + qnn_quantizer.set_quant_config( + quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver ) - qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. custom_annotations = (custom_annotate_llama_matmul_16a8w,) else: From b4c6fe1eeb3888d626799c2d05043224c094e7ca Mon Sep 17 00:00:00 2001 From: cccclai Date: Tue, 5 Nov 2024 19:55:45 -0800 Subject: [PATCH 10/59] Refactor Init function arg Differential Revision: D65499276 Pull Request resolved: https://github.com/pytorch/executorch/pull/6673 --- runtime/executor/test/backend_integration_test.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/runtime/executor/test/backend_integration_test.cpp b/runtime/executor/test/backend_integration_test.cpp index 9180d77aa3..bf9dc0033f 100644 --- a/runtime/executor/test/backend_integration_test.cpp +++ b/runtime/executor/test/backend_integration_test.cpp @@ -55,7 +55,7 @@ class StubBackend final : public BackendInterface { using InitFn = std::function( FreeableBuffer*, ArrayRef, - MemoryAllocator*)>; + BackendInitContext&)>; using ExecuteFn = std::function; using DestroyFn = std::function; @@ -83,8 +83,7 @@ class StubBackend final : public BackendInterface { FreeableBuffer* processed, ArrayRef compile_specs) const override { if (init_fn_) { - return init_fn_.value()( - processed, compile_specs, context.get_runtime_allocator()); + return init_fn_.value()(processed, compile_specs, context); } // Return a benign value otherwise. return nullptr; @@ -351,7 +350,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) { StubBackend::singleton().install_init( [&](FreeableBuffer* processed, ET_UNUSED ArrayRef compile_specs, - ET_UNUSED MemoryAllocator* runtime_allocator) + ET_UNUSED BackendInitContext& backend_init_context) -> Result { init_called = true; processed_data = processed->data(); @@ -395,7 +394,7 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) { StubBackend::singleton().install_init( [&](FreeableBuffer* processed, ET_UNUSED ArrayRef compile_specs, - ET_UNUSED MemoryAllocator* runtime_allocator) + ET_UNUSED BackendInitContext& backend_init_context) -> Result { init_processed = processed; return processed; @@ -492,7 +491,7 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) { StubBackend::singleton().install_init( [&](FreeableBuffer* processed, ET_UNUSED ArrayRef compile_specs, - ET_UNUSED MemoryAllocator* runtime_allocator) + ET_UNUSED BackendInitContext& backend_init_context) -> Result { processed_data = processed->data(); processed->Free(); @@ -606,7 +605,7 @@ TEST_P(DelegateDataAlignmentTest, ExpectedDataAlignment) { StubBackend::singleton().install_init( [&](FreeableBuffer* processed, ET_UNUSED ArrayRef compile_specs, - ET_UNUSED MemoryAllocator* runtime_allocator) + ET_UNUSED BackendInitContext& backend_init_context) -> Result { processed_data = processed->data(); return nullptr; From 179d4954c842b809cada95e27308d44e4e601c0f Mon Sep 17 00:00:00 2001 From: David Lin Date: Tue, 5 Nov 2024 20:34:59 -0800 Subject: [PATCH 11/59] [data loader] move logic for FD data loader out of file_data_loader (#6682) move fd loader Co-authored-by: lind --- .../executor_runner/executor_runner.cpp | 40 +- examples/portable/executor_runner/targets.bzl | 1 + extension/data_loader/file_data_loader.cpp | 84 +--- extension/data_loader/file_data_loader.h | 26 -- .../file_descriptor_data_loader.cpp | 292 ++++++++++++++ .../data_loader/file_descriptor_data_loader.h | 112 ++++++ extension/data_loader/targets.bzl | 15 + .../test/file_data_loader_test.cpp | 97 ----- .../test/file_descriptor_data_loader_test.cpp | 359 ++++++++++++++++++ extension/data_loader/test/targets.bzl | 11 + 10 files changed, 835 insertions(+), 202 deletions(-) create mode 100644 extension/data_loader/file_descriptor_data_loader.cpp create mode 100644 extension/data_loader/file_descriptor_data_loader.h create mode 100644 extension/data_loader/test/file_descriptor_data_loader_test.cpp diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index f1a2d3b8f2..35e58fec03 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -45,6 +46,7 @@ DEFINE_bool( "True if the model_path passed is a file descriptor with the prefix \"fd:///\"."); using executorch::extension::FileDataLoader; +using executorch::extension::FileDescriptorDataLoader; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::HierarchicalAllocator; @@ -56,6 +58,33 @@ using executorch::runtime::Program; using executorch::runtime::Result; using executorch::runtime::Span; +static Result getProgram( + const bool is_fd_uri, + const char* model_path) { + // Create a loader to get the data of the program file. This demonstrates both + // FileDataLoader and FileDescriptorDataLoader. There are other DataLoaders + // that use mmap() or point to data that's already in memory, and users can + // create their own DataLoaders to load from arbitrary sources. + if (!is_fd_uri) { + Result loader = FileDataLoader::from(model_path); + + ET_CHECK_MSG( + loader.ok(), + "FileDataLoader::from() failed: 0x%" PRIx32, + (uint32_t)loader.error()); + return Program::load(&loader.get()); + } else { + Result loader = + FileDescriptorDataLoader::fromFileDescriptorUri(model_path); + + ET_CHECK_MSG( + loader.ok(), + "FileDescriptorDataLoader::fromFileDescriptorUri() failed: 0x%" PRIx32, + (uint32_t)loader.error()); + return Program::load(&loader.get()); + } +} + int main(int argc, char** argv) { executorch::runtime::runtime_init(); @@ -75,18 +104,9 @@ int main(int argc, char** argv) { const char* model_path = FLAGS_model_path.c_str(); const bool is_fd_uri = FLAGS_is_fd_uri; - Result loader = is_fd_uri - ? FileDataLoader::fromFileDescriptorUri(model_path) - : FileDataLoader::from(model_path); - - ET_CHECK_MSG( - loader.ok(), - "FileDataLoader::from() failed: 0x%" PRIx32, - (uint32_t)loader.error()); - // Parse the program file. This is immutable, and can also be reused between // multiple execution invocations across multiple threads. - Result program = Program::load(&loader.get()); + Result program = getProgram(is_fd_uri, model_path); if (!program.ok()) { ET_LOG(Error, "Failed to parse model file %s", model_path); return 1; diff --git a/examples/portable/executor_runner/targets.bzl b/examples/portable/executor_runner/targets.bzl index 9cddaa4ed7..83c63d3a41 100644 --- a/examples/portable/executor_runner/targets.bzl +++ b/examples/portable/executor_runner/targets.bzl @@ -15,6 +15,7 @@ def define_common_targets(): deps = [ "//executorch/runtime/executor:program", "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/data_loader:file_descriptor_data_loader", "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/runner_util:inputs", ], diff --git a/extension/data_loader/file_data_loader.cpp b/extension/data_loader/file_data_loader.cpp index f5a3b94d84..1d097cfd98 100644 --- a/extension/data_loader/file_data_loader.cpp +++ b/extension/data_loader/file_data_loader.cpp @@ -43,8 +43,6 @@ namespace extension { namespace { -static constexpr char kFdFilesystemPrefix[] = "fd:///"; - /** * Returns true if the value is an integer power of 2. */ @@ -76,36 +74,25 @@ FileDataLoader::~FileDataLoader() { ::close(fd_); } -static Result getFDFromUri(const char* file_descriptor_uri) { - // check if the uri starts with the prefix "fd://" +Result FileDataLoader::from( + const char* file_name, + size_t alignment) { ET_CHECK_OR_RETURN_ERROR( - strncmp( - file_descriptor_uri, - kFdFilesystemPrefix, - strlen(kFdFilesystemPrefix)) == 0, + is_power_of_2(alignment), InvalidArgument, - "File descriptor uri (%s) does not start with %s", - file_descriptor_uri, - kFdFilesystemPrefix); - - // strip "fd:///" from the uri - int fd_len = strlen(file_descriptor_uri) - strlen(kFdFilesystemPrefix); - char fd_without_prefix[fd_len + 1]; - memcpy( - fd_without_prefix, - &file_descriptor_uri[strlen(kFdFilesystemPrefix)], - fd_len); - fd_without_prefix[fd_len] = '\0'; + "Alignment %zu is not a power of 2", + alignment); - // check if remaining fd string is a valid integer - int fd = ::atoi(fd_without_prefix); - return fd; -} + // Use open() instead of fopen() to avoid the layer of buffering that + // fopen() does. We will be reading large portions of the file in one shot, + // so buffering does not help. + int fd = ::open(file_name, O_RDONLY); + if (fd < 0) { + ET_LOG( + Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno); + return Error::AccessFailed; + } -Result FileDataLoader::fromFileDescriptor( - const char* file_name, - const int fd, - size_t alignment) { // Cache the file size. struct stat st; int err = ::fstat(fd, &st); @@ -132,47 +119,6 @@ Result FileDataLoader::fromFileDescriptor( return FileDataLoader(fd, file_size, alignment, file_name_copy); } -Result FileDataLoader::fromFileDescriptorUri( - const char* file_descriptor_uri, - size_t alignment) { - ET_CHECK_OR_RETURN_ERROR( - is_power_of_2(alignment), - InvalidArgument, - "Alignment %zu is not a power of 2", - alignment); - - auto parsed_fd = getFDFromUri(file_descriptor_uri); - if (!parsed_fd.ok()) { - return parsed_fd.error(); - } - - int fd = parsed_fd.get(); - - return fromFileDescriptor(file_descriptor_uri, fd, alignment); -} - -Result FileDataLoader::from( - const char* file_name, - size_t alignment) { - ET_CHECK_OR_RETURN_ERROR( - is_power_of_2(alignment), - InvalidArgument, - "Alignment %zu is not a power of 2", - alignment); - - // Use open() instead of fopen() to avoid the layer of buffering that - // fopen() does. We will be reading large portions of the file in one shot, - // so buffering does not help. - int fd = ::open(file_name, O_RDONLY); - if (fd < 0) { - ET_LOG( - Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno); - return Error::AccessFailed; - } - - return fromFileDescriptor(file_name, fd, alignment); -} - namespace { /** * FreeableBuffer::FreeFn-compatible callback. diff --git a/extension/data_loader/file_data_loader.h b/extension/data_loader/file_data_loader.h index 959684137b..7cf2a92c4a 100644 --- a/extension/data_loader/file_data_loader.h +++ b/extension/data_loader/file_data_loader.h @@ -26,27 +26,6 @@ namespace extension { */ class FileDataLoader final : public executorch::runtime::DataLoader { public: - /** - * Creates a new FileDataLoader that wraps the named file descriptor, and the - * ownership of the file descriptor is passed. This helper is used when ET is - * running in a process that does not have access to the filesystem, and the - * caller is able to open the file and pass the file descriptor. - * - * @param[in] file_descriptor_uri File descriptor with the prefix "fd:///", - * followed by the file descriptor number. - * @param[in] alignment Alignment in bytes of pointers returned by this - * instance. Must be a power of two. - * - * @returns A new FileDataLoader on success. - * @retval Error::InvalidArgument `alignment` is not a power of two. - * @retval Error::AccessFailed `file_name` could not be opened, or its size - * could not be found. - * @retval Error::MemoryAllocationFailed Internal memory allocation failure. - */ - static executorch::runtime::Result fromFileDescriptorUri( - const char* file_descriptor_uri, - size_t alignment = alignof(std::max_align_t)); - /** * Creates a new FileDataLoader that wraps the named file. * @@ -100,11 +79,6 @@ class FileDataLoader final : public executorch::runtime::DataLoader { void* buffer) const override; private: - static executorch::runtime::Result fromFileDescriptor( - const char* file_name, - const int fd, - size_t alignment = alignof(std::max_align_t)); - FileDataLoader( int fd, size_t file_size, diff --git a/extension/data_loader/file_descriptor_data_loader.cpp b/extension/data_loader/file_descriptor_data_loader.cpp new file mode 100644 index 0000000000..48e81fd706 --- /dev/null +++ b/extension/data_loader/file_descriptor_data_loader.cpp @@ -0,0 +1,292 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +namespace executorch { +namespace extension { + +namespace { + +static constexpr char kFdFilesystemPrefix[] = "fd:///"; + +/** + * Returns true if the value is an integer power of 2. + */ +static bool is_power_of_2(size_t value) { + return value > 0 && (value & ~(value - 1)) == value; +} + +/** + * Returns the next alignment for a given pointer. + */ +static uint8_t* align_pointer(void* ptr, size_t alignment) { + intptr_t addr = reinterpret_cast(ptr); + if ((addr & (alignment - 1)) == 0) { + // Already aligned. + return reinterpret_cast(ptr); + } + // Bump forward. + addr = (addr | (alignment - 1)) + 1; + return reinterpret_cast(addr); +} +} // namespace + +FileDescriptorDataLoader::~FileDescriptorDataLoader() { + // file_descriptor_uri_ can be nullptr if this instance was moved from, but + // freeing a null pointer is safe. + std::free(const_cast(file_descriptor_uri_)); + // fd_ can be -1 if this instance was moved from, but closing a negative fd is + // safe (though it will return an error). + ::close(fd_); +} + +static Result getFDFromUri(const char* file_descriptor_uri) { + // check if the uri starts with the prefix "fd://" + ET_CHECK_OR_RETURN_ERROR( + strncmp( + file_descriptor_uri, + kFdFilesystemPrefix, + strlen(kFdFilesystemPrefix)) == 0, + InvalidArgument, + "File descriptor uri (%s) does not start with %s", + file_descriptor_uri, + kFdFilesystemPrefix); + + // strip "fd:///" from the uri + int fd_len = strlen(file_descriptor_uri) - strlen(kFdFilesystemPrefix); + char fd_without_prefix[fd_len + 1]; + memcpy( + fd_without_prefix, + &file_descriptor_uri[strlen(kFdFilesystemPrefix)], + fd_len); + fd_without_prefix[fd_len] = '\0'; + + // check if remaining fd string is a valid integer + int fd = ::atoi(fd_without_prefix); + return fd; +} + +Result +FileDescriptorDataLoader::fromFileDescriptorUri( + const char* file_descriptor_uri, + size_t alignment) { + ET_CHECK_OR_RETURN_ERROR( + is_power_of_2(alignment), + InvalidArgument, + "Alignment %zu is not a power of 2", + alignment); + + auto parsed_fd = getFDFromUri(file_descriptor_uri); + if (!parsed_fd.ok()) { + return parsed_fd.error(); + } + + int fd = parsed_fd.get(); + + // Cache the file size. + struct stat st; + int err = ::fstat(fd, &st); + if (err < 0) { + ET_LOG( + Error, + "Could not get length of %s: %s (%d)", + file_descriptor_uri, + ::strerror(errno), + errno); + ::close(fd); + return Error::AccessFailed; + } + size_t file_size = st.st_size; + + // Copy the filename so we can print better debug messages if reads fail. + const char* file_descriptor_uri_copy = ::strdup(file_descriptor_uri); + if (file_descriptor_uri_copy == nullptr) { + ET_LOG(Error, "strdup(%s) failed", file_descriptor_uri); + ::close(fd); + return Error::MemoryAllocationFailed; + } + + return FileDescriptorDataLoader( + fd, file_size, alignment, file_descriptor_uri_copy); +} + +namespace { +/** + * FreeableBuffer::FreeFn-compatible callback. + * + * `context` is actually a ptrdiff_t value (not a pointer) that contains the + * offset in bytes between `data` and the actual pointer to free. + */ +void FreeSegment(void* context, void* data, ET_UNUSED size_t size) { + ptrdiff_t offset = reinterpret_cast(context); + ET_DCHECK_MSG(offset >= 0, "Unexpected offset %ld", (long int)offset); + std::free(static_cast(data) - offset); +} +} // namespace + +Result FileDescriptorDataLoader::load( + size_t offset, + size_t size, + ET_UNUSED const DataLoader::SegmentInfo& segment_info) const { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. + fd_ >= 0, + InvalidState, + "Uninitialized"); + ET_CHECK_OR_RETURN_ERROR( + offset + size <= file_size_, + InvalidArgument, + "File %s: offset %zu + size %zu > file_size_ %zu", + file_descriptor_uri_, + offset, + size, + file_size_); + + // Don't bother allocating/freeing for empty segments. + if (size == 0) { + return FreeableBuffer(nullptr, 0, /*free_fn=*/nullptr); + } + + // Allocate memory for the FreeableBuffer. + size_t alloc_size = size; + if (alignment_ > alignof(std::max_align_t)) { + // malloc() will align to smaller values, but we must manually align to + // larger values. + alloc_size += alignment_; + } + void* buffer = std::malloc(alloc_size); + if (buffer == nullptr) { + ET_LOG( + Error, + "Reading from %s at offset %zu: malloc(%zd) failed", + file_descriptor_uri_, + offset, + size); + return Error::MemoryAllocationFailed; + } + + // Align. + void* aligned_buffer = align_pointer(buffer, alignment_); + + // Assert that the alignment didn't overflow the buffer. + ET_DCHECK_MSG( + reinterpret_cast(aligned_buffer) + size <= + reinterpret_cast(buffer) + alloc_size, + "aligned_buffer %p + size %zu > buffer %p + alloc_size %zu", + aligned_buffer, + size, + buffer, + alloc_size); + + auto err = load_into(offset, size, segment_info, aligned_buffer); + if (err != Error::Ok) { + // Free `buffer`, which is what malloc() gave us, not `aligned_buffer`. + std::free(buffer); + return err; + } + + // We can't naively free this pointer, since it may not be what malloc() gave + // us. Pass the offset to the real buffer as context. This is the number of + // bytes that need to be subtracted from the FreeableBuffer::data() pointer to + // find the actual pointer to free. + return FreeableBuffer( + aligned_buffer, + size, + FreeSegment, + /*free_fn_context=*/ + reinterpret_cast( + // Using signed types here because it will produce a signed ptrdiff_t + // value, though for us it will always be non-negative. + reinterpret_cast(aligned_buffer) - + reinterpret_cast(buffer))); +} + +Result FileDescriptorDataLoader::size() const { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. + fd_ >= 0, + InvalidState, + "Uninitialized"); + return file_size_; +} + +ET_NODISCARD Error FileDescriptorDataLoader::load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const { + ET_CHECK_OR_RETURN_ERROR( + // Probably had its value moved to another instance. + fd_ >= 0, + InvalidState, + "Uninitialized"); + ET_CHECK_OR_RETURN_ERROR( + offset + size <= file_size_, + InvalidArgument, + "File %s: offset %zu + size %zu > file_size_ %zu", + file_descriptor_uri_, + offset, + size, + file_size_); + ET_CHECK_OR_RETURN_ERROR( + buffer != nullptr, InvalidArgument, "Provided buffer cannot be null"); + + // Read the data into the aligned address. + size_t needed = size; + uint8_t* buf = reinterpret_cast(buffer); + + while (needed > 0) { + // Reads on macOS will fail with EINVAL if size > INT32_MAX. + const auto chunk_size = std::min( + needed, static_cast(std::numeric_limits::max())); + const auto nread = ::pread(fd_, buf, chunk_size, offset); + if (nread < 0 && errno == EINTR) { + // Interrupted by a signal; zero bytes read. + continue; + } + if (nread <= 0) { + // nread == 0 means EOF, which we shouldn't see if we were able to read + // the full amount. nread < 0 means an error occurred. + ET_LOG( + Error, + "Reading from %s: failed to read %zu bytes at offset %zu: %s", + file_descriptor_uri_, + size, + offset, + nread == 0 ? "EOF" : strerror(errno)); + return Error::AccessFailed; + } + needed -= nread; + buf += nread; + offset += nread; + } + return Error::Ok; +} + +} // namespace extension +} // namespace executorch diff --git a/extension/data_loader/file_descriptor_data_loader.h b/extension/data_loader/file_descriptor_data_loader.h new file mode 100644 index 0000000000..6f51f0f7a6 --- /dev/null +++ b/extension/data_loader/file_descriptor_data_loader.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace executorch { +namespace extension { + +/** + * A DataLoader that loads segments from a file descriptor, allocating the + * memory with `malloc()`. This data loader is used when ET is running in a + * process that does not have access to the filesystem, and the caller is able + * to open the file and pass the file descriptor. + * + * Note that this will keep the file open for the duration of its lifetime, to + * avoid the overhead of opening it again for every load() call. + */ +class FileDescriptorDataLoader final : public executorch::runtime::DataLoader { + public: + /** + * Creates a new FileDescriptorDataLoader that wraps the named file + * descriptor, and the ownership of the file descriptor is passed. + * + * @param[in] file_descriptor_uri File descriptor with the prefix "fd:///", + * followed by the file descriptor number. + * @param[in] alignment Alignment in bytes of pointers returned by this + * instance. Must be a power of two. + * + * @returns A new FileDescriptorDataLoader on success. + * @retval Error::InvalidArgument `alignment` is not a power of two. + * @retval Error::AccessFailed `file_descriptor_uri` is incorrectly formatted, + * or its size could not be found. + * @retval Error::MemoryAllocationFailed Internal memory allocation failure. + */ + static executorch::runtime::Result + fromFileDescriptorUri( + const char* file_descriptor_uri, + size_t alignment = alignof(std::max_align_t)); + + // Movable to be compatible with Result. + FileDescriptorDataLoader(FileDescriptorDataLoader&& rhs) noexcept + : file_descriptor_uri_(rhs.file_descriptor_uri_), + file_size_(rhs.file_size_), + alignment_(rhs.alignment_), + fd_(rhs.fd_) { + const_cast(rhs.file_descriptor_uri_) = nullptr; + const_cast(rhs.file_size_) = 0; + const_cast(rhs.alignment_) = 0; + const_cast(rhs.fd_) = -1; + } + + ~FileDescriptorDataLoader() override; + + ET_NODISCARD + executorch::runtime::Result load( + size_t offset, + size_t size, + const DataLoader::SegmentInfo& segment_info) const override; + + ET_NODISCARD executorch::runtime::Result size() const override; + + ET_NODISCARD executorch::runtime::Error load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const override; + + private: + FileDescriptorDataLoader( + int fd, + size_t file_size, + size_t alignment, + const char* file_descriptor_uri) + : file_descriptor_uri_(file_descriptor_uri), + file_size_(file_size), + alignment_(alignment), + fd_(fd) {} + + // Not safely copyable. + FileDescriptorDataLoader(const FileDescriptorDataLoader&) = delete; + FileDescriptorDataLoader& operator=(const FileDescriptorDataLoader&) = delete; + FileDescriptorDataLoader& operator=(FileDescriptorDataLoader&&) = delete; + + const char* const file_descriptor_uri_; // Owned by the instance. + const size_t file_size_; + const size_t alignment_; + const int fd_; // Owned by the instance. +}; + +} // namespace extension +} // namespace executorch + +namespace torch { +namespace executor { +namespace util { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::extension::FileDescriptorDataLoader; +} // namespace util +} // namespace executor +} // namespace torch diff --git a/extension/data_loader/targets.bzl b/extension/data_loader/targets.bzl index 4886df03a7..fcc7cba541 100644 --- a/extension/data_loader/targets.bzl +++ b/extension/data_loader/targets.bzl @@ -52,6 +52,21 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "file_descriptor_data_loader", + srcs = ["file_descriptor_data_loader.cpp"], + exported_headers = ["file_descriptor_data_loader.h"], + visibility = [ + "//executorch/test/...", + "//executorch/runtime/executor/test/...", + "//executorch/extension/data_loader/test/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/runtime/core:core", + ], + ) + runtime.cxx_library( name = "mmap_data_loader", srcs = ["mmap_data_loader.cpp"], diff --git a/extension/data_loader/test/file_data_loader_test.cpp b/extension/data_loader/test/file_data_loader_test.cpp index b8921aebb5..1d4f4c1619 100644 --- a/extension/data_loader/test/file_data_loader_test.cpp +++ b/extension/data_loader/test/file_data_loader_test.cpp @@ -40,103 +40,6 @@ class FileDataLoaderTest : public ::testing::TestWithParam { } }; -TEST_P(FileDataLoaderTest, InBoundsFileDescriptorLoadsSucceed) { - // Write some heterogeneous data to a file. - uint8_t data[256]; - for (int i = 0; i < sizeof(data); ++i) { - data[i] = i; - } - TempFile tf(data, sizeof(data)); - - int fd = ::open(tf.path().c_str(), O_RDONLY); - - // Wrap it in a loader. - Result fdl = FileDataLoader::fromFileDescriptorUri( - ("fd:///" + std::to_string(fd)).c_str(), alignment()); - ASSERT_EQ(fdl.error(), Error::Ok); - - // size() should succeed and reflect the total size. - Result size = fdl->size(); - ASSERT_EQ(size.error(), Error::Ok); - EXPECT_EQ(*size, sizeof(data)); - - // Load the first bytes of the data. - { - Result fb = fdl->load( - /*offset=*/0, - /*size=*/8, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), 8); - EXPECT_EQ( - 0, - std::memcmp( - fb->data(), - "\x00\x01\x02\x03" - "\x04\x05\x06\x07", - fb->size())); - - // Freeing should release the buffer and clear out the segment. - fb->Free(); - EXPECT_EQ(fb->size(), 0); - EXPECT_EQ(fb->data(), nullptr); - - // Safe to call multiple times. - fb->Free(); - } - - // Load the last few bytes of the data, a different size than the first time. - { - Result fb = fdl->load( - /*offset=*/sizeof(data) - 3, - /*size=*/3, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), 3); - EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); - } - - // Loading all of the data succeeds. - { - Result fb = fdl->load( - /*offset=*/0, - /*size=*/sizeof(data), - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), sizeof(data)); - EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); - } - - // Loading zero-sized data succeeds, even at the end of the data. - { - Result fb = fdl->load( - /*offset=*/sizeof(data), - /*size=*/0, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_EQ(fb->size(), 0); - } -} - -TEST_P(FileDataLoaderTest, FileDescriptorLoadPrefixFail) { - // Write some heterogeneous data to a file. - uint8_t data[256]; - for (int i = 0; i < sizeof(data); ++i) { - data[i] = i; - } - TempFile tf(data, sizeof(data)); - - int fd = ::open(tf.path().c_str(), O_RDONLY); - - // Wrap it in a loader. - Result fdl = FileDataLoader::fromFileDescriptorUri( - std::to_string(fd).c_str(), alignment()); - ASSERT_EQ(fdl.error(), Error::InvalidArgument); -} - TEST_P(FileDataLoaderTest, InBoundsLoadsSucceed) { // Write some heterogeneous data to a file. uint8_t data[256]; diff --git a/extension/data_loader/test/file_descriptor_data_loader_test.cpp b/extension/data_loader/test/file_descriptor_data_loader_test.cpp new file mode 100644 index 0000000000..0258611cbd --- /dev/null +++ b/extension/data_loader/test/file_descriptor_data_loader_test.cpp @@ -0,0 +1,359 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +#include +#include +#include +#include + +using namespace ::testing; +using executorch::extension::FileDescriptorDataLoader; +using executorch::extension::testing::TempFile; +using executorch::runtime::DataLoader; +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +class FileDescriptorDataLoaderTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + } + + // The alignment in bytes that tests should use. The values are set by the + // list in the INSTANTIATE_TEST_SUITE_P call below. + size_t alignment() const { + return GetParam(); + } +}; + +TEST_P(FileDescriptorDataLoaderTest, InBoundsFileDescriptorLoadsSucceed) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); + + // Load the first bytes of the data. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/8, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 8); + EXPECT_EQ( + 0, + std::memcmp( + fb->data(), + "\x00\x01\x02\x03" + "\x04\x05\x06\x07", + fb->size())); + + // Freeing should release the buffer and clear out the segment. + fb->Free(); + EXPECT_EQ(fb->size(), 0); + EXPECT_EQ(fb->data(), nullptr); + + // Safe to call multiple times. + fb->Free(); + } + + // Load the last few bytes of the data, a different size than the first time. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) - 3, + /*size=*/3, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 3); + EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); + } + + // Loading all of the data succeeds. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), sizeof(data)); + EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); + } + + // Loading zero-sized data succeeds, even at the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data), + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_EQ(fb->size(), 0); + } +} + +TEST_P(FileDescriptorDataLoaderTest, FileDescriptorLoadPrefixFail) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + std::to_string(fd).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::InvalidArgument); +} + +TEST_P(FileDescriptorDataLoaderTest, InBoundsLoadsSucceed) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); + + // Load the first bytes of the data. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/8, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 8); + EXPECT_EQ( + 0, + std::memcmp( + fb->data(), + "\x00\x01\x02\x03" + "\x04\x05\x06\x07", + fb->size())); + + // Freeing should release the buffer and clear out the segment. + fb->Free(); + EXPECT_EQ(fb->size(), 0); + EXPECT_EQ(fb->data(), nullptr); + + // Safe to call multiple times. + fb->Free(); + } + + // Load the last few bytes of the data, a different size than the first time. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) - 3, + /*size=*/3, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 3); + EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); + } + + // Loading all of the data succeeds. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), sizeof(data)); + EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); + } + + // Loading zero-sized data succeeds, even at the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data), + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_EQ(fb->size(), 0); + } +} + +TEST_P(FileDescriptorDataLoaderTest, OutOfBoundsLoadFails) { + // Create a temp file; contents don't matter. + uint8_t data[256] = {}; + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // Loading beyond the end of the data should fail. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data) + 1, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + EXPECT_NE(fb.error(), Error::Ok); + } + + // Loading zero bytes still fails if it's past the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) + 1, + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + EXPECT_NE(fb.error(), Error::Ok); + } +} + +TEST_P(FileDescriptorDataLoaderTest, BadAlignmentFails) { + // Create a temp file; contents don't matter. + uint8_t data[256] = {}; + TempFile tf(data, sizeof(data)); + + // Creating a loader with default alignment works fine. + { + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + } + + // Bad alignments fail. + const std::vector bad_alignments = {0, 3, 5, 17}; + for (size_t bad_alignment : bad_alignments) { + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), bad_alignment); + ASSERT_EQ(fdl.error(), Error::InvalidArgument); + } +} + +// Tests that the move ctor works. +TEST_P(FileDescriptorDataLoaderTest, MoveCtor) { + // Create a loader. + std::string contents = "FILE_CONTENTS"; + TempFile tf(contents); + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + EXPECT_EQ(fdl->size().get(), contents.size()); + + // Move it into another instance. + FileDescriptorDataLoader fdl2(std::move(*fdl)); + + // Old loader should now be invalid. + EXPECT_EQ( + fdl->load( + 0, + 0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)) + .error(), + Error::InvalidState); + EXPECT_EQ(fdl->size().error(), Error::InvalidState); + + // New loader should point to the file. + EXPECT_EQ(fdl2.size().get(), contents.size()); + Result fb = fdl2.load( + /*offset=*/0, + contents.size(), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + ASSERT_EQ(fb->size(), contents.size()); + EXPECT_EQ(0, std::memcmp(fb->data(), contents.data(), fb->size())); +} + +// Test that the deprecated From method (capital 'F') still works. +TEST_P(FileDescriptorDataLoaderTest, DEPRECATEDFrom) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); +} + +// Run all FileDescriptorDataLoaderTests multiple times, varying the return +// value of `GetParam()` based on the `testing::Values` list. The tests will +// interpret the value as "alignment". +INSTANTIATE_TEST_SUITE_P( + VariedSegments, + FileDescriptorDataLoaderTest, + testing::Values( + 1, + 4, + alignof(std::max_align_t), + 2 * alignof(std::max_align_t), + 128, + 1024)); diff --git a/extension/data_loader/test/targets.bzl b/extension/data_loader/test/targets.bzl index 9c83d6d56b..d424413c1b 100644 --- a/extension/data_loader/test/targets.bzl +++ b/extension/data_loader/test/targets.bzl @@ -38,6 +38,17 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "file_descriptor_data_loader_test", + srcs = [ + "file_descriptor_data_loader_test.cpp", + ], + deps = [ + "//executorch/extension/testing_util:temp_file", + "//executorch/extension/data_loader:file_descriptor_data_loader", + ], + ) + runtime.cxx_test( name = "mmap_data_loader_test", srcs = [ From c438f8dc29f86fafc6bcf670127f435c3ac5257b Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 6 Nov 2024 06:53:47 +0100 Subject: [PATCH 12/59] Refactor pytest config + add default dump dir option Differential Revision: D65481575 Pull Request resolved: https://github.com/pytorch/executorch/pull/6637 --- backends/arm/test/common.py | 85 ++++++++++++++++++++++++-- backends/arm/test/tester/arm_tester.py | 8 +++ 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 2ae86b1d1e..af44fa4474 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -11,6 +11,10 @@ import subprocess import sys import tempfile +from datetime import datetime +from enum import auto, Enum +from pathlib import Path +from typing import Any import pytest @@ -19,7 +23,15 @@ from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.exir.backend.compile_spec_schema import CompileSpec -_enabled_options: list[str] = [] + +class arm_test_options(Enum): + quantize_io = auto() + corstone300 = auto() + dump_path = auto() + date_format = auto() + + +_test_options: dict[arm_test_options, Any] = {} # ==== Pytest hooks ==== @@ -27,19 +39,30 @@ def pytest_addoption(parser): parser.addoption("--arm_quantize_io", action="store_true") parser.addoption("--arm_run_corstone300", action="store_true") + parser.addoption("--default_dump_path", default=None) + parser.addoption("--date_format", default="%d-%b-%H:%M:%S") def pytest_configure(config): if config.option.arm_quantize_io: load_libquantized_ops_aot_lib() - _enabled_options.append("quantize_io") + _test_options[arm_test_options.quantize_io] = True if config.option.arm_run_corstone300: corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55") if not corstone300_exists: raise RuntimeError( "Tests are run with --arm_run_corstone300 but corstone300 FVP is not installed." ) - _enabled_options.append("corstone300") + _test_options[arm_test_options.corstone300] = True + if config.option.default_dump_path: + dump_path = Path(config.option.default_dump_path).expanduser() + if dump_path.exists() and os.path.isdir(dump_path): + _test_options[arm_test_options.dump_path] = dump_path + else: + raise RuntimeError( + f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory." + ) + _test_options[arm_test_options.date_format] = config.option.date_format logging.basicConfig(level=logging.INFO, stream=sys.stdout) @@ -54,6 +77,18 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_if_aot_lib_not_loaded) +def pytest_sessionstart(session): + pass + + +def pytest_sessionfinish(session, exitstatus): + if get_option(arm_test_options.dump_path): + _clean_dir( + get_option(arm_test_options.dump_path), + f"ArmTester_{get_option(arm_test_options.date_format)}.log", + ) + + # ==== End of Pytest hooks ===== @@ -76,7 +111,9 @@ def load_libquantized_ops_aot_lib(): torch.ops.load_library(library_path) -def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: +def is_option_enabled( + option: str | arm_test_options, fail_if_not_enabled: bool = False +) -> bool: """ Returns whether an option is successfully enabled, i.e. if the flag was given to pytest and the necessary requirements are available. @@ -87,7 +124,10 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: The optional parameter 'fail_if_not_enabled' makes the function raise a RuntimeError instead of returning False. """ - if option.lower() in _enabled_options: + if isinstance(option, str): + option = arm_test_options[option.lower()] + + if option in _test_options and _test_options[option]: return True else: if fail_if_not_enabled: @@ -96,6 +136,12 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: return False +def get_option(option: arm_test_options) -> Any | None: + if option in _test_options: + return _test_options[option] + return None + + def maybe_get_tosa_collate_path() -> str | None: """ Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the @@ -219,3 +265,32 @@ def get_u85_compile_spec_unbuilt( .dump_intermediate_artifacts_to(artifact_path) ) return compile_spec + + +def current_time_formated() -> str: + """Return current time as a formated string""" + return datetime.now().strftime(get_option(arm_test_options.date_format)) + + +def _clean_dir(dir: Path, filter: str, num_save=10): + sorted_files: list[tuple[datetime, Path]] = [] + for file in dir.iterdir(): + try: + creation_time = datetime.strptime(file.name, filter) + insert_index = -1 + for i, to_compare in enumerate(sorted_files): + compare_time = to_compare[0] + if creation_time < compare_time: + insert_index = i + break + if insert_index == -1 and len(sorted_files) < num_save: + sorted_files.append((creation_time, file)) + else: + sorted_files.insert(insert_index, (creation_time, file)) + except ValueError: + continue + + if len(sorted_files) > num_save: + for remove in sorted_files[0 : len(sorted_files) - num_save]: + file = remove[1] + file.unlink() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 59d326109d..096bc2b22f 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -22,6 +22,11 @@ ArmQuantizer, get_symmetric_quantization_config, ) +from executorch.backends.arm.test.common import ( + arm_test_options, + current_time_formated, + get_option, +) from executorch.backends.arm.test.runner_utils import ( _get_input_quantization_params, @@ -575,6 +580,9 @@ def _get_tosa_operator_distribution( def _dump_str(to_print: str, path_to_dump: Optional[str] = None): + default_dump_path = get_option(arm_test_options.dump_path) + if not path_to_dump and default_dump_path: + path_to_dump = default_dump_path / f"ArmTester_{current_time_formated()}.log" if path_to_dump: with open(path_to_dump, "a") as fp: fp.write(to_print) From b8e0ef9b2c043fd3246b5b747566125954d25774 Mon Sep 17 00:00:00 2001 From: SaoirseARM <44364573+SaoirseARM@users.noreply.github.com> Date: Wed, 6 Nov 2024 06:00:50 +0000 Subject: [PATCH 13/59] Add cat/stack ops to generic annotator Differential Revision: D65067790 Pull Request resolved: https://github.com/pytorch/executorch/pull/6494 --- backends/arm/quantizer/arm_quantizer.py | 1 - backends/arm/quantizer/arm_quantizer_utils.py | 11 --- .../quantization_annotation/__init__.py | 1 - .../quantization_annotation/cat_annotator.py | 68 ------------------- .../generic_annotator.py | 34 ++++++++-- backends/arm/test/ops/test_slice.py | 12 +--- backends/arm/test/ops/test_split.py | 12 +--- backends/arm/test/ops/test_squeeze.py | 11 +-- backends/arm/test/ops/test_unsqueeze.py | 11 +-- backends/arm/test/ops/test_view.py | 17 ++--- 10 files changed, 44 insertions(+), 134 deletions(-) delete mode 100644 backends/arm/quantizer/quantization_annotation/cat_annotator.py diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 6a68eb2eb9..e61fbc5bbe 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -268,7 +268,6 @@ class ArmQuantizer(Quantizer): "sub", "mul", "mm", - "cat", "one_to_one", "generic", "sum", diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 4a910611bc..a1d7bfe296 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -144,21 +144,10 @@ def is_share_obs_or_fq_op(op: Callable) -> bool: torch.ops.aten.mean.dim, torch.ops.aten.permute.default, torch.ops.aten.permute_copy.default, - torch.ops.aten.squeeze.dim, - torch.ops.aten.squeeze.dims, - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze_copy.dim, - torch.ops.aten.unsqueeze.default, - torch.ops.aten.unsqueeze_copy.default, # TODO: remove? torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.avg_pool2d.default, - torch.ops.aten.view_copy.default, - torch.ops.aten.view.default, torch.ops.aten.full.default, - torch.ops.aten.slice.Tensor, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes.default, torch.ops.aten.flatten.using_ints, torch.ops.aten.dropout.default, operator.getitem, diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index bc3184298f..7eaa837c5b 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -51,7 +51,6 @@ def decorator(annotator: AnnotatorType): from . import ( # noqa adaptive_ang_pool2d_annotator, add_annotator, - cat_annotator, conv_annotator, generic_annotator, linear_annotator, diff --git a/backends/arm/quantizer/quantization_annotation/cat_annotator.py b/backends/arm/quantizer/quantization_annotation/cat_annotator.py deleted file mode 100644 index 6e138cd9de..0000000000 --- a/backends/arm/quantizer/quantization_annotation/cat_annotator.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import itertools -from typing import Callable, cast, List, Optional - -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("cat") -def _annotate_cat( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) - cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) - annotated_partitions = [] - for cat_partition in cat_partitions: - annotated_partitions.append(cat_partition.nodes) - cat_node = cat_partition.output_nodes[0] - if arm_quantizer_utils.is_annotated(cat_node): - continue - - input_acts = cast(list[torch.fx.Node], cat_node.args[0]) - input_act0 = input_acts[0] - - input_act_qspec = quantization_config.get_input_act_qspec() - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) - - input_qspec_map = {} - - # First input is set to input qspec from the quantization config. - if isinstance(input_act0, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm): - continue - input_qspec_map[input_act0] = input_act_qspec - - # For the rest of the inputs, share qspec with first. - # If we can't quantize any of the inputs, abort annotation. - for input_act in input_acts[1:]: - if isinstance(input_act, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm): - continue - if input_act is not input_act0: - input_qspec_map[input_act] = shared_with_input0_qspec - - if input_qspec_map is not None: - cat_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=shared_with_input0_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py index f91df1398e..126051f158 100644 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe - from typing import Callable, List, Optional import torch @@ -24,6 +23,9 @@ # DATA LAYOUT OPS torch.ops.aten.squeeze.default, torch.ops.aten.squeeze_copy.default, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default, torch.ops.aten.reshape.default, @@ -33,19 +35,21 @@ # torch.ops.aten.view_as_complex_copy.default, # torch.ops.aten.view_as_real.default, # torch.ops.aten.view_as_real_copy.default, + torch.ops.aten.view.default, torch.ops.aten.view_copy.default, torch.ops.aten.select.int, torch.ops.aten.select_copy.int, torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor, - # 'concat' should be handled separately as it has a sequence of inputs and - # makes the implementation unnecessary complicated. - # torch.ops.aten.concat.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes.default, torch.ops.aten.transpose.Dimname, torch.ops.aten.transpose.int, torch.ops.aten.transpose_copy.int, torch.ops.aten.tile.default, torch.ops.aten.flip.default, + torch.ops.aten.cat.default, + torch.ops.aten.stack.default, ] @@ -66,15 +70,31 @@ def _annotate_generic( if arm_quantizer_utils.is_annotated(node): continue - input_node = node.args[0] + input_acts = node.args[0] + + # Check to see if there are multiple inputs. + # this allows for stack/cat ops to be annotated + # in a similar way. + has_multi_inputs = isinstance(input_acts, list) + + input_act0 = input_acts[0] if has_multi_inputs else input_acts # Using a non-shared quantization spec here as a SharedQuantizationSpec # can lead to a recursion. _annotate_input_qspec_map( - node, input_node, quantization_config.get_input_act_qspec() + node, input_act0, quantization_config.get_input_act_qspec() ) - _annotate_output_qspec(node, SharedQuantizationSpec((input_node, node))) + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node)) + + if has_multi_inputs: + # For the rest of the inputs, share qspec with first. + for input_act in input_acts[1:]: + if input_act is not input_act0: + node.meta["quantization_annotation"].input_qspec_map[ + input_act + ] = shared_with_input0_qspec + _annotate_output_qspec(node, shared_with_input0_qspec) arm_quantizer_utils.mark_nodes_as_annotated([node]) annotated_partitions.append([node]) diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 0bab21f907..18db358fdf 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -8,13 +8,9 @@ from typing import Tuple import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) + from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -59,7 +55,6 @@ def _test_slice_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor], permute: bool ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, @@ -68,7 +63,7 @@ def _test_slice_tosa_BI_pipeline( permute_memory_to_nhwc=permute ), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check(["torch.ops.aten.slice.Tensor"]) .to_edge() @@ -84,14 +79,13 @@ def _test_slice_ethos_BI_pipeline( module: torch.nn.Module, test_data: Tuple[torch.Tensor], ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_u55_compile_spec(), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check(["torch.ops.aten.slice.Tensor"]) .to_edge() diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 3f6edc0c2b..8ed0e723f1 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -7,13 +7,9 @@ import unittest import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) + from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -79,14 +75,13 @@ def _test_split_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: test_data_t ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec(), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .to_edge() .partition() @@ -98,14 +93,13 @@ def _test_split_tosa_BI_pipeline( def _test_split_ethosu_BI_pipeline( self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: test_data_t ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=compile_spec, ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check(["torch.ops.aten.split.Tensor"]) .to_edge() diff --git a/backends/arm/test/ops/test_squeeze.py b/backends/arm/test/ops/test_squeeze.py index c9d7d42195..c3f1edf37b 100644 --- a/backends/arm/test/ops/test_squeeze.py +++ b/backends/arm/test/ops/test_squeeze.py @@ -13,14 +13,9 @@ import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -83,14 +78,13 @@ def _test_squeeze_tosa_BI_pipeline( test_data: Tuple[torch.Tensor, Optional[tuple[int]]], export_target: str, ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec(), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({export_target: 1}) .to_edge() @@ -107,10 +101,9 @@ def _test_squeeze_ethosu_BI_pipeline( test_data: Tuple[torch.Tensor, Optional[tuple[int]]], export_target: str, ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({export_target: 1}) .to_edge() diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py index 1cc597c066..36bb93b796 100644 --- a/backends/arm/test/ops/test_unsqueeze.py +++ b/backends/arm/test/ops/test_unsqueeze.py @@ -13,14 +13,9 @@ import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -54,14 +49,13 @@ def _test_unsqueeze_tosa_MI_pipeline( def _test_unsqueeze_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int] ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec(), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) .to_edge() @@ -77,14 +71,13 @@ def _test_unsqueeze_ethosu_BI_pipeline( module: torch.nn.Module, test_data: Tuple[torch.Tensor, int], ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=compile_spec, ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) .to_edge() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index fe1f2981da..54e80702e3 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -13,14 +13,9 @@ import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - ArmQuantizer, - get_symmetric_quantization_config, -) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -74,14 +69,13 @@ def _test_view_tosa_MI_pipeline( def _test_view_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec(), ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .quantize() .export() .check_count({"torch.ops.aten.view.default": 1}) .to_edge() @@ -97,10 +91,13 @@ def _test_view_ethos_BI_pipeline( module: torch.nn.Module, test_data: Tuple[torch.Tensor], ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( - ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() .export() .check_count({"torch.ops.aten.view.default": 1}) .to_edge() From 03b1ef26df33efc9de41528195cf85bca497ff6d Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Tue, 5 Nov 2024 23:08:41 -0700 Subject: [PATCH 14/59] Remove IR check after aten in arm Differential Revision: D65504497 Pull Request resolved: https://github.com/pytorch/executorch/pull/6677 --- backends/arm/test/ops/test_batch_norm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index 4935e910d6..bfe1146a90 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -536,9 +536,6 @@ def _test_batchnorm2d_tosa_MI_pipeline( compile_spec=common.get_tosa_compile_spec(), ) .export() - .check_count( - {"torch.ops.aten._native_batch_norm_legit_no_training.default": 1} - ) .check_not(["torch.ops.quantized_decomposed"]) .to_edge() .check_count( From 17ad8d3f69d365a77dda590e1d40497befaa6b69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Wed, 6 Nov 2024 13:37:23 +0100 Subject: [PATCH 15/59] Fix type handling for output types from TOSA reference model (#6660) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I80953a699e4861b901af4b2fb17d47d3d7efcedd Signed-off-by: Per Åstrand --- backends/arm/test/runner_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 3e9d3620cc..d2ee113a5d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -448,8 +448,11 @@ def run_tosa_ref_model( ), "There are no quantization parameters, check output parameters" tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale + if tosa_ref_output.dtype == np.double: + tosa_ref_output = tosa_ref_output.astype("float32") + # tosa_output is a numpy array, convert to torch tensor for comparison - tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32"))) + tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output)) return tosa_ref_outputs @@ -457,7 +460,9 @@ def run_tosa_ref_model( def prep_data_for_save( data, is_quantized: bool, input_name: str, quant_param: QuantizationParams ): - data_np = np.array(data.detach(), order="C").astype(np.float32) + data_np = np.array(data.detach(), order="C").astype( + f"{data.dtype}".replace("torch.", "") + ) if is_quantized: assert quant_param.node_name in input_name, ( From 026fe0b3cefe18389009a6cb66f9848cb138432b Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Wed, 6 Nov 2024 13:59:46 -0800 Subject: [PATCH 16/59] Add support for bits16 in ETDump Differential Revision: D65552835 Pull Request resolved: https://github.com/pytorch/executorch/pull/6697 --- devtools/inspector/_inspector_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index f712644303..c2e92f0914 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -112,6 +112,7 @@ def get_scalar_type_size(scalar_type: ScalarType) -> Tuple[torch.dtype, int]: ScalarType.BYTE: (torch.uint8, 1), ScalarType.CHAR: (torch.int8, 1), ScalarType.BOOL: (torch.bool, 1), + ScalarType.BITS16: (torch.uint16, 2), ScalarType.SHORT: (torch.int16, 2), ScalarType.HALF: (torch.float16, 2), ScalarType.INT: (torch.int, 4), From b07386c85d88d679ad79992f583a309cf9941fdd Mon Sep 17 00:00:00 2001 From: Dave Bort Date: Wed, 6 Nov 2024 14:32:40 -0800 Subject: [PATCH 17/59] Remove custom implementation of string_view Differential Revision: D65454239 Pull Request resolved: https://github.com/pytorch/executorch/pull/6651 --- runtime/core/portable_type/string_view.h | 563 +---------------------- 1 file changed, 2 insertions(+), 561 deletions(-) diff --git a/runtime/core/portable_type/string_view.h b/runtime/core/portable_type/string_view.h index 977a0f542d..8e28fa022c 100644 --- a/runtime/core/portable_type/string_view.h +++ b/runtime/core/portable_type/string_view.h @@ -8,572 +8,13 @@ #pragma once -#include -#include +#include -#include - -// TODO(T154113473): Document this file namespace executorch { namespace runtime { namespace etensor { -namespace internal { - -/** - * Reimplementation of std::string_view for C++11. - * Mostly copy pasted from the c10 implementation but modified some to remove - * broader c10 dependencies - */ -template -class basic_string_view final { - public: - using value_type = CharT; - using pointer = CharT*; - using const_pointer = const CharT*; - using reference = CharT&; - using const_reference = const CharT&; - using const_iterator = const CharT*; - using iterator = const_iterator; - using size_type = std::size_t; - - static constexpr size_type npos = size_type(-1); - - constexpr basic_string_view() noexcept : begin_(nullptr), size_(0) {} - - explicit constexpr basic_string_view(const_pointer str, size_type count) - : begin_(str), size_(count) {} - - /* implicit */ constexpr basic_string_view(const_pointer str) - : basic_string_view(str, strlen_(str)) {} - - constexpr const_iterator begin() const noexcept { - return cbegin(); - } - - constexpr const_iterator cbegin() const noexcept { - return begin_; - } - - constexpr const_iterator end() const noexcept { - return cend(); - } - - constexpr const_iterator cend() const noexcept { - return begin_ + size_; - } - - friend constexpr const_iterator begin(basic_string_view sv) noexcept { - return sv.begin(); - } - - friend constexpr const_iterator end(basic_string_view sv) noexcept { - return sv.end(); - } - - constexpr const_reference operator[](size_type pos) const { - return at_(pos); - } - - constexpr const_reference at(size_type pos) const { - ET_CHECK_MSG( - pos >= size_, - "string_view::operator[] or string_view::at() out of range"); - return at_(pos); - } - - constexpr const_reference front() const { - return *begin_; - } - - constexpr const_reference back() const { - return *(begin_ + size_ - 1); - } - - constexpr const_pointer data() const noexcept { - return begin_; - } - - constexpr size_type size() const noexcept { - return size_; - } - - constexpr size_type length() const noexcept { - return size(); - } - - constexpr bool empty() const noexcept { - return size() == 0; - } - - void remove_prefix(size_type n) { - ET_CHECK_MSG(n > size(), "basic_string_view::remove_prefix: out of range."); - begin_ += n; - size_ -= n; - } - - void remove_suffix(size_type n) { - ET_CHECK_MSG(n > size(), "basic_string_view::remove_suffix: out of range."); - size_ -= n; - } - - void swap(basic_string_view& sv) noexcept { - auto tmp = *this; - *this = sv; - sv = tmp; - } - - size_type copy(pointer dest, size_type count, size_type pos = 0) const { - ET_CHECK_MSG(pos > size_, "basic_string_view::copy: out of range."); - size_type copy_length = min_(count, size_ - pos); - for (auto iter = begin() + pos, end = iter + copy_length; iter != end;) { - *(dest++) = *(iter++); - } - return copy_length; - } - - constexpr basic_string_view substr(size_type pos = 0, size_type count = npos) - const { - ET_CHECK_MSG( - pos > size_, "basic_string_view::substr parameter out of bounds."); - return substr_(pos, count); - } - - constexpr int compare(basic_string_view rhs) const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - for (size_t i = 0, end = min_(size(), rhs.size()); i < end; ++i) { - if (at_(i) < rhs.at_(i)) { - return -1; - } else if (at_(i) > rhs.at_(i)) { - return 1; - } - } - if (size() < rhs.size()) { - return -1; - } else if (size() > rhs.size()) { - return 1; - } - return 0; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (size() == 0 && rhs.size() == 0) ? 0 - : (size() == 0) ? -1 - : (rhs.size() == 0) ? 1 - : (front() < rhs.front()) ? -1 - : (front() > rhs.front()) ? 1 - : substr_(1).compare(rhs.substr_(1)); -#endif - } - - constexpr int compare(size_type pos1, size_type count1, basic_string_view v) - const { - return substr(pos1, count1).compare(v); - } - - constexpr int compare( - size_type pos1, - size_type count1, - basic_string_view v, - size_type pos2, - size_type count2) const { - return substr(pos1, count1).compare(v.substr(pos2, count2)); - } - - constexpr int compare(const_pointer s) const { - return compare(basic_string_view(s)); - } - - constexpr int compare(size_type pos1, size_type count1, const_pointer s) - const { - return substr(pos1, count1).compare(basic_string_view(s)); - } - - constexpr int compare( - size_type pos1, - size_type count1, - const_pointer s, - size_type count2) const { - return substr(pos1, count1).compare(basic_string_view(s, count2)); - } - - friend constexpr bool operator==( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return lhs.equals_(rhs); - } - - friend constexpr bool operator!=( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return !(lhs == rhs); - } - - friend constexpr bool operator<( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return lhs.compare(rhs) < 0; - } - - friend constexpr bool operator>=( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return !(lhs < rhs); - } - - friend constexpr bool operator>( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return rhs < lhs; - } - - friend constexpr bool operator<=( - basic_string_view lhs, - basic_string_view rhs) noexcept { - return !(lhs > rhs); - } - - constexpr bool starts_with(basic_string_view prefix) const noexcept { - return (prefix.size() > size()) ? false - : prefix.equals_(substr_(0, prefix.size())); - } - - constexpr bool starts_with(CharT prefix) const noexcept { - return !empty() && prefix == front(); - } - - constexpr bool starts_with(const_pointer prefix) const { - return starts_with(basic_string_view(prefix)); - } - - constexpr bool ends_with(basic_string_view suffix) const noexcept { - return (suffix.size() > size()) - ? false - : suffix.equals_(substr_(size() - suffix.size(), suffix.size())); - } - - constexpr bool ends_with(CharT suffix) const noexcept { - return !empty() && suffix == back(); - } - - constexpr bool ends_with(const_pointer suffix) const { - return ends_with(basic_string_view(suffix)); - } - - constexpr size_type find(basic_string_view v, size_type pos = 0) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (v.size() == 0) { - return pos <= size() ? pos : npos; - } - - if (pos + v.size() <= size()) { - for (size_type cur = pos, end = size() - v.size(); cur <= end; ++cur) { - if (v.at_(0) == at_(cur) && - v.substr_(1).equals_(substr_(cur + 1, v.size() - 1))) { - return cur; - } - } - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (v.size() == 0) ? (pos <= size() ? pos : npos) - : (pos + v.size() > size()) ? npos - : (v.at_(0) == at_(pos) && - v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) - ? pos - : find(v, pos + 1); -#endif - } - - constexpr size_type find(CharT ch, size_type pos = 0) const noexcept { - return find_first_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type find(const_pointer s, size_type pos, size_type count) - const { - return find(basic_string_view(s, count), pos); - } - - constexpr size_type find(const_pointer s, size_type pos = 0) const { - return find(basic_string_view(s), pos); - } - - constexpr size_type rfind(basic_string_view v, size_type pos = npos) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (v.size() == 0) { - return pos <= size() ? pos : size(); - } - - if (v.size() <= size()) { - pos = min_(size() - v.size(), pos); - do { - if (v.at_(0) == at_(pos) && - v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) { - return pos; - } - } while (pos-- > 0); - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (v.size() == 0) ? (pos <= size() ? pos : size()) - : (v.size() > size()) ? npos - : (size() - v.size() < pos) ? rfind(v, size() - v.size()) - : (v.at_(0) == at_(pos) && - v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) - ? pos - : (pos == 0) ? npos - : rfind(v, pos - 1); -#endif - } - - constexpr size_type rfind(CharT ch, size_type pos = npos) const noexcept { - return find_last_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type rfind(const_pointer s, size_type pos, size_type count) - const { - return rfind(basic_string_view(s, count), pos); - } - - constexpr size_type rfind(const_pointer s, size_type pos = npos) const { - return rfind(basic_string_view(s), pos); - } - - constexpr size_type find_first_of(basic_string_view v, size_type pos = 0) - const noexcept { - return find_first_if_(pos, stringViewContainsChar_{v}); - } - - constexpr size_type find_first_of(CharT ch, size_type pos = 0) - const noexcept { - return find_first_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type - find_first_of(const_pointer s, size_type pos, size_type count) const { - return find_first_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_first_of(const_pointer s, size_type pos = 0) const { - return find_first_of(basic_string_view(s), pos); - } - - constexpr size_type find_last_of(basic_string_view v, size_type pos = npos) - const noexcept { - return find_last_if_(pos, stringViewContainsChar_{v}); - } - - constexpr size_type find_last_of(CharT ch, size_type pos = npos) - const noexcept { - return find_last_if_(pos, charIsEqual_{ch}); - } - - constexpr size_type - find_last_of(const_pointer s, size_type pos, size_type count) const { - return find_last_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_last_of(const_pointer s, size_type pos = npos) - const { - return find_last_of(basic_string_view(s), pos); - } - - constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0) - const noexcept { - return find_first_if_(pos, stringViewDoesNotContainChar_{v}); - } - - constexpr size_type find_first_not_of(CharT ch, size_type pos = 0) - const noexcept { - return find_first_if_(pos, charIsNotEqual_{ch}); - } - - constexpr size_type - find_first_not_of(const_pointer s, size_type pos, size_type count) const { - return find_first_not_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_first_not_of(const_pointer s, size_type pos = 0) - const { - return find_first_not_of(basic_string_view(s), pos); - } - - constexpr size_type find_last_not_of( - basic_string_view v, - size_type pos = npos) const noexcept { - return find_last_if_(pos, stringViewDoesNotContainChar_{v}); - } - - constexpr size_type find_last_not_of(CharT ch, size_type pos = npos) - const noexcept { - return find_last_if_(pos, charIsNotEqual_{ch}); - } - - constexpr size_type - find_last_not_of(const_pointer s, size_type pos, size_type count) const { - return find_last_not_of(basic_string_view(s, count), pos); - } - - constexpr size_type find_last_not_of(const_pointer s, size_type pos = npos) - const { - return find_last_not_of(basic_string_view(s), pos); - } - - private: - static constexpr std::size_t min_(const std::size_t a, const std::size_t b) { - return (b < a) ? b : a; - } - - static constexpr size_type strlen_(const_pointer str) noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - const_pointer current = str; - while (*current != '\0') { - ++current; - } - return current - str; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (*str == '\0') ? 0 : 1 + strlen_(str + 1); -#endif - } - - constexpr const_reference at_(size_type pos) const noexcept { - return *(begin_ + pos); - } - - constexpr basic_string_view substr_(size_type pos = 0, size_type count = npos) - const { - return basic_string_view{begin_ + pos, min_(count, size() - pos)}; - } - - template - constexpr size_type find_first_if_(size_type pos, Condition&& condition) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (pos + 1 <= size()) { - for (size_type cur = pos; cur < size(); ++cur) { - if (condition(at_(cur))) { - return cur; - } - } - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (pos + 1 > size()) ? npos - : condition(at_(pos)) - ? pos - : find_first_if_(pos + 1, std::forward(condition)); -#endif - } - - template - constexpr size_type find_last_if_(size_type pos, Condition&& condition) - const noexcept { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster. - if (size() > 0) { - pos = min_(size() - 1, pos); - do { - if (condition(at_(pos))) { - return pos; - } - } while (pos-- > 0); - } - return npos; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (size() == 0) ? npos - : (pos >= size()) - ? find_last_if_(size() - 1, std::forward(condition)) - : condition(at_(pos)) ? pos - : (pos == 0) - ? npos - : find_last_if_(pos - 1, std::forward(condition)); -#endif - } - - constexpr bool equals_(basic_string_view rhs) const { -#if __cpp_constexpr >= 201304 - // if we are in C++14, write it iteratively. This is faster than the - // recursive C++11 implementation below. - if (size() != rhs.size()) { - return false; - } - // memcmp would be faster than this loop, but memcmp isn't constexpr - for (typename basic_string_view::size_type pos = 0; pos < size(); - ++pos) { - if (at_(pos) != rhs.at_(pos)) { - return false; - } - } - return true; -#else - // if we are in C++11, we need to do it recursively because of constexpr - // restrictions. - return (size() != rhs.size()) ? false - : (size() == 0) ? true - : (front() != rhs.front()) ? false - : (substr_(1).equals_(rhs.substr_(1))); -#endif - } - - struct charIsEqual_ final { - CharT expected; - constexpr bool operator()(CharT actual) const noexcept { - return expected == actual; - } - }; - - struct charIsNotEqual_ final { - CharT expected; - constexpr bool operator()(CharT actual) const noexcept { - return expected != actual; - } - }; - - struct stringViewContainsChar_ final { - basic_string_view expected; - constexpr bool operator()(CharT ch) const noexcept { - return npos != expected.find(ch); - } - }; - - struct stringViewDoesNotContainChar_ final { - basic_string_view expected; - constexpr bool operator()(CharT ch) const noexcept { - return npos == expected.find(ch); - } - }; - - const_pointer begin_; - size_type size_; -}; - -template -inline void swap( - basic_string_view& lhs, - basic_string_view& rhs) noexcept { - lhs.swap(rhs); -} - -} // namespace internal - -using string_view = internal::basic_string_view; +using std::string_view; } // namespace etensor } // namespace runtime From 785ebf3ff2e6e57aa76320e66a45cec3eb69d117 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Wed, 6 Nov 2024 14:47:51 -0800 Subject: [PATCH 18/59] Add trunc scalar prim_op Differential Revision: D65057149 Pull Request resolved: https://github.com/pytorch/executorch/pull/6580 --- exir/passes/executorch_prim_ops_registry.py | 9 +++++++++ kernels/prim_ops/register_prim_ops.cpp | 16 ++++++++++++++++ kernels/prim_ops/test/prim_ops_test.cpp | 20 ++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index 6362a47112..4af233aaa6 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math import operator from typing import Dict, Set, Union @@ -14,6 +15,8 @@ from torch._ops import OpOverload from torch.library import Library +# pyre-unsafe + executorch_prims_lib = Library("executorch_prim", "DEF") @@ -91,7 +94,13 @@ def neg(a: _SymScalar) -> _SymScalar: return -a # pyre-ignore +@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar") +def trunc(a: _SymScalar) -> _SymScalar: + return math.trunc(a) # pyre-ignore + + _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = { + math.trunc: ops.backend.executorch_prim.trunc.Scalar, operator.sub: ops.backend.executorch_prim.sub.Scalar, operator.mul: ops.backend.executorch_prim.mul.Scalar, operator.add: ops.backend.executorch_prim.add.Scalar, diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 7872b0d173..5755ab8d66 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -12,6 +12,8 @@ #include #include +#include + using torch::executor::function::et_copy_index; namespace torch { @@ -301,6 +303,20 @@ static Kernel prim_ops[] = { } }), + // trunc.Scalar(Scalar a) -> Scalar + Kernel( + "executorch_prim::trunc.Scalar", + [](KernelRuntimeContext& context, EValue** stack) { + (void)context; + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isDouble()) { + out = EValue(static_cast(trunc(a.toDouble()))); + } else { + ET_CHECK_MSG(false, "%zu", (size_t)a.tag); + } + }), + // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index), // executorch_prim::et_view.default(Tensor, int[]) -> Tensor diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index 4b4b35a232..3581a470da 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -503,5 +503,25 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) { getOpsFn("executorch_prim::et_view.default")(context, bad_stack), ""); } +TEST_F(RegisterPrimOpsTest, TestTrunc) { + std::array inputs = { + 0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999}; + std::array expected = {0, 0, 0, 0, 1, 1, 0, -1, -1, 9}; + + for (auto i = 0; i < inputs.size(); i++) { + EValue values[2]; + values[0] = EValue(inputs[i]); + values[1] = EValue(0.0); + + EValue* stack[2]; + for (size_t j = 0; j < 2; j++) { + stack[j] = &values[j]; + } + + getOpsFn("executorch_prim::trunc.Scalar")(context, stack); + EXPECT_EQ(stack[1]->toInt(), expected[i]); + } +} + } // namespace executor } // namespace torch From 6051b2fee44c8509a1af0168743a70699a67a2fd Mon Sep 17 00:00:00 2001 From: mcremon-meta <134334895+mcremon-meta@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:49:26 -0800 Subject: [PATCH 19/59] Add per_tensor overload for quantized_conv Differential Revision: D65306801 Pull Request resolved: https://github.com/pytorch/executorch/pull/6648 --- backends/cadence/aot/ops_registrations.py | 54 +++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index d47ea3f21a..fce6ce5736 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -66,6 +66,12 @@ lib.define( "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)" +) +lib.define( + "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)" @@ -171,6 +177,54 @@ def quantized_conv_meta( return input.new_empty(output_size, dtype=input.dtype) +@register_fake("cadence::quantized_conv.per_tensor") +def quantized_conv_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, + channel_last: bool = False, +) -> torch.Tensor: + if channel_last: + out_channels, *kernel_size, _ = weight.shape + else: + out_channels, _, *kernel_size = weight.shape + + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size[0], + channel_last, + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, channel_last + ) + ) + + return input.new_empty(output_size, dtype=input.dtype) + + @register_fake("cadence::quantized_layer_norm") def quantized_layer_norm_meta( input: torch.Tensor, From 8f82198b53b66559fca6da86d9a123264829a65d Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Wed, 6 Nov 2024 21:18:48 -0500 Subject: [PATCH 20/59] Add dvorjackz to ghstack_land.yml (#6644) --- .github/workflows/ghstack_land.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ghstack_land.yml b/.github/workflows/ghstack_land.yml index 12782c66dd..e3b02d2a94 100644 --- a/.github/workflows/ghstack_land.yml +++ b/.github/workflows/ghstack_land.yml @@ -5,6 +5,7 @@ on: branches: - 'gh/cccclai/[0-9]+/base' - 'gh/dbort/[0-9]+/base' + - 'gh/dvorjackz/[0-9]+/base' - 'gh/guangy10/[0-9]+/base' - 'gh/helunwencser/[0-9]+/base' - 'gh/jorgep31415/[0-9]+/base' From 70f15e6f9fba40918777ef33643839d13849273c Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 6 Nov 2024 19:20:22 -0800 Subject: [PATCH 21/59] [ET-VK] Fake u16vecn for devserver (#6704) Pull Request resolved: https://github.com/pytorch/executorch/pull/6675 ## Context Copy-pasted from the newly added `maybe_fake_u16vec3` function in the codegen script: > There is a latency benefit to using u16vecn variables to store texture position variables instead of ivecn, likely due to reduced register pressure. However, SwiftShader does not support 16 bit integer types in shaders, so this is a crude way to fallback to using ivecn to store texture positions so that testing with SwiftShader is still possible. ghstack-source-id: 252234981 @exported-using-ghexport Differential Revision: [D65501674](https://our.internmc.facebook.com/intern/diff/D65501674/) Co-authored-by: Stephen Jia --- backends/vulkan/runtime/gen_vulkan_spv.py | 26 ++++++++++++++++++- .../runtime/graph/ops/glsl/q_8w_linear.glsl | 2 ++ backends/vulkan/targets.bzl | 1 + 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 46db1e3a98..39d023e765 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -540,6 +540,7 @@ def __init__( env: Dict[Any, Any], glslc_path: Optional[str], glslc_flags: str = "", + replace_u16vecn: bool = False, ) -> None: if isinstance(src_dir_paths, str): self.src_dir_paths = [src_dir_paths] @@ -549,6 +550,7 @@ def __init__( self.env = env self.glslc_path = glslc_path self.glslc_flags = glslc_flags + self.replace_u16vecn = replace_u16vecn self.glsl_src_files: Dict[str, str] = {} self.template_yaml_files: List[str] = [] @@ -705,6 +707,22 @@ def constructOutputMap(self) -> None: self.create_shader_params(), ) + def maybe_replace_u16vecn(self, input_text: str) -> str: + """ + There is a latency benefit to using u16vecn variables to store texture position + variables instead of ivecn, likely due to reduced register pressure. However, + SwiftShader does not support 16 bit integer types in shaders, so this is a crude + way to fallback to using ivecn to store texture positions so that testing with + SwiftShader is still possible. + """ + if not self.replace_u16vecn: + return input_text + if "codegen-nosub" in input_text: + return input_text + + input_text = input_text.replace("u16vec", "ivec") + return input_text + def generateSPV(self, output_dir: str) -> Dict[str, str]: output_file_map = {} @@ -716,6 +734,7 @@ def process_shader(shader_paths_pair): with codecs.open(source_glsl, "r", encoding="utf-8") as input_file: input_text = input_file.read() + input_text = self.maybe_replace_u16vecn(input_text) output_text = preprocess(input_text, shader_params) glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl") @@ -1029,6 +1048,7 @@ def main(argv: List[str]) -> int: parser.add_argument("-c", "--glslc-path", required=True, help="") parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp") parser.add_argument("-o", "--output-path", required=True, help="") + parser.add_argument("--replace-u16vecn", action="store_true", default=False) parser.add_argument("--optimize_size", action="store_true", help="") parser.add_argument("--optimize", action="store_true", help="") parser.add_argument( @@ -1056,7 +1076,11 @@ def main(argv: List[str]) -> int: glslc_flags += "-O" shader_generator = SPVGenerator( - options.glsl_paths, env, options.glslc_path, glslc_flags + options.glsl_paths, + env, + options.glslc_path, + glslc_flags=glslc_flags, + replace_u16vecn=options.replace_u16vecn, ) output_spv_files = shader_generator.generateSPV(options.tmp_dir_path) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index ecfb44d431..f679732ddb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +// codegen-nosub + #version 450 core #define PRECISION ${PRECISION} diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 9521bcacdb..2c4671afa0 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -27,6 +27,7 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False): select({ "DEFAULT": "", "ovr_config//os:android": "--optimize", + "ovr_config//os:linux": "--replace-u16vecn", }) ) From 545535b63ad82f852ca0043570386f17b7af9e89 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 6 Nov 2024 19:43:19 -0800 Subject: [PATCH 22/59] [Executorch] enable sleef consistently (#6705) Pull Request resolved: https://github.com/pytorch/executorch/pull/6524 Earlier only android platofrms had support for sleef ghstack-source-id: 252186435 @exported-using-ghexport //oss lint broken on unrelated issue @bypass-github-export-checks @exported-using-ghexport Differential Revision: [D64571782](https://our.internmc.facebook.com/intern/diff/D64571782/) Co-authored-by: Kimish Patel --- extension/llm/custom_ops/targets.bzl | 9 +++- kernels/optimized/lib_defs.bzl | 46 ++++++++++++++---- kernels/optimized/op_registration_util.bzl | 7 +-- kernels/optimized/test/targets.bzl | 4 +- .../executorch/kernels/optimized/lib_defs.bzl | 48 +++++++++++++++---- .../optimized/op_registration_util.bzl | 4 +- 6 files changed, 91 insertions(+), 27 deletions(-) diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index 6b9f9cb959..781225afed 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -1,10 +1,14 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load( + "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", + "get_vec_preprocessor_flags", + "get_vec_deps", +) load( "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "get_compiler_optimization_flags", ) - def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -26,6 +30,7 @@ def define_common_targets(): "op_sdpa.h", "op_update_quantized_cache.h", ], + preprocessor_flags = get_vec_preprocessor_flags(), exported_deps = [ "//executorch/runtime/kernel:kernel_includes", "//executorch/kernels/portable/cpu:scalar_utils", @@ -38,7 +43,7 @@ def define_common_targets(): deps = [ "//executorch/kernels/portable/cpu/util:reduce_util", "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform", - ], + ] + get_vec_deps(), compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(), visibility = [ "//executorch/...", diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index fb1c9a17f9..659c7afe09 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -15,16 +15,44 @@ load( # functions in order to declare the required compiler flags needed in order to # access CPU vector intrinsics. -def get_vec_android_preprocessor_flags(): - preprocessor_flags = [ - ( - "^android-arm64.*$", - [ +def get_vec_preprocessor_flags(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + preprocessor_flags = select({ + "ovr_config//os:linux-x86_64": [ "-DET_BUILD_ARM_VEC256_WITH_SLEEF", - ], - ), - ] - return preprocessor_flags + ] if not runtime.is_oss else [], + "ovr_config//os:iphoneos-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return preprocessor_flags + return [] + +def get_vec_deps(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + deps = select({ + "ovr_config//os:iphoneos-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return deps + return [] def get_vec_cxx_preprocessor_flags(): preprocessor_flags = [ diff --git a/kernels/optimized/op_registration_util.bzl b/kernels/optimized/op_registration_util.bzl index 6e74836bb7..6839454be2 100644 --- a/kernels/optimized/op_registration_util.bzl +++ b/kernels/optimized/op_registration_util.bzl @@ -2,7 +2,8 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/build:selects.bzl", "selects") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", - "get_vec_android_preprocessor_flags", + "get_vec_preprocessor_flags", + "get_vec_deps", ) load( "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", @@ -94,8 +95,8 @@ def define_op_library(name, deps): compiler_flags = ["-Wno-missing-prototypes"] + get_compiler_optimization_flags(), deps = [ "//executorch/runtime/kernel:kernel_includes", - ] + augmented_deps, - fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(), + ] + augmented_deps + get_vec_deps(), + preprocessor_flags = get_vec_preprocessor_flags(), # sleef needs to be added as a direct dependency of the operator target when building for Android, # or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of # dependencies are not transitive diff --git a/kernels/optimized/test/targets.bzl b/kernels/optimized/test/targets.bzl index d2ee2880c6..e4740a9ad7 100644 --- a/kernels/optimized/test/targets.bzl +++ b/kernels/optimized/test/targets.bzl @@ -1,7 +1,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", - "get_vec_android_preprocessor_flags", + "get_vec_preprocessor_flags", "get_vec_cxx_preprocessor_flags", ) load("@fbsource//xplat/executorch/kernels/test:util.bzl", "define_supported_features_lib") @@ -27,7 +27,7 @@ def _lib_test_bin(name, extra_deps = [], in_cpu = False): "//executorch/kernels/optimized{}:{}".format(cpu_path, lib_root), ] + extra_deps, cxx_platform_preprocessor_flags = get_vec_cxx_preprocessor_flags(), - fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(), + preprocessor_flags = get_vec_preprocessor_flags(), ) def define_common_targets(): diff --git a/shim/xplat/executorch/kernels/optimized/lib_defs.bzl b/shim/xplat/executorch/kernels/optimized/lib_defs.bzl index 79ce6b02b3..bd3284c42a 100644 --- a/shim/xplat/executorch/kernels/optimized/lib_defs.bzl +++ b/shim/xplat/executorch/kernels/optimized/lib_defs.bzl @@ -16,16 +16,46 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") # functions in order to declare the required compiler flags needed in order to # access CPU vector intrinsics. -def get_vec_android_preprocessor_flags(): - preprocessor_flags = [ - ( - "^android-arm64.*$", - [ +# This oopy from kernels/optimized/lib_defs.bzl is not necessary. +# This file really needs to be removed +def get_vec_preprocessor_flags(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + preprocessor_flags = select({ + "ovr_config//os:iphoneos": [ "-DET_BUILD_ARM_VEC256_WITH_SLEEF", - ], - ), - ] - return preprocessor_flags + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "-DET_BUILD_ARM_VEC256_WITH_SLEEF", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return preprocessor_flags + return [] + +def get_vec_deps(): + if not runtime.is_oss: + # various ovr_configs are not available in oss + deps = select({ + "ovr_config//os:linux-x86_64": [ + "fbsource//third-party/sleef:sleef", + ] if not runtime.is_oss else [], + "ovr_config//os:iphoneos": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:macos-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "ovr_config//os:android-arm64": [ + "fbsource//third-party/sleef:sleef_arm", + ] if not runtime.is_oss else [], + "DEFAULT": [], + }) + return deps + return [] def get_vec_cxx_preprocessor_flags(): preprocessor_flags = [ diff --git a/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl index c9fe4ec912..37a68abaa0 100644 --- a/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/optimized/op_registration_util.bzl @@ -9,7 +9,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/build:selects.bzl", "selects") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", - "get_vec_android_preprocessor_flags", + "get_vec_preprocessor_flags", ) def op_target(name, deps = []): @@ -98,7 +98,7 @@ def define_op_library(name, deps): deps = [ "//executorch/runtime/kernel:kernel_includes", ] + augmented_deps, - fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(), + preprocessor_flags = get_vec_preprocessor_flags(), # sleef needs to be added as a direct dependency of the operator target when building for Android, # or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of # dependencies are not transitive From 713d8a115ba968e0bf3aadfc22221769365d3d19 Mon Sep 17 00:00:00 2001 From: Adrian Lundell <36153706+AdrianLundell@users.noreply.github.com> Date: Thu, 7 Nov 2024 07:58:41 +0100 Subject: [PATCH 23/59] Add max_pool2d op to Arm backend (#6285) * Add max_pool2d op to Arm backend. - Adds node visitor and unittests - Adds remove_getitem_op pass to convert (maxpool_get inidices + getitem) -> maxpool2d op * Expected failures only for FVP --- backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/arm_partitioner.py | 1 + backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_max_pool2d.py | 77 ++++++ backends/arm/quantizer/arm_quantizer_utils.py | 1 + backends/arm/test/common.py | 11 + backends/arm/test/ops/test_max_pool.py | 248 ++++++++++++++++++ 7 files changed, 341 insertions(+) create mode 100644 backends/arm/operators/op_max_pool2d.py create mode 100644 backends/arm/test/ops/test_max_pool.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b3ddecbc29..a6c9cf1d06 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -43,6 +43,7 @@ from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import ( UnsqueezeScalarPlaceholdersPass, ) +from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.pass_manager import PassManager @@ -58,6 +59,7 @@ def transform_to_backend_pipeline( ): """Apply passes before transforming program to backend""" self.add_pass(CastInt64ToInt32Pass(exported_program)) + self.add_pass(RemoveGetItemPass()) self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(SizeAdjustConv2DPass()) self.add_pass(RemoveClonePass()) diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index 7309287998..bdd4b80f29 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -55,6 +55,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.native_layer_norm.default, exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.repeat.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a8ddf1c8f0..5e188aea77 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -20,6 +20,7 @@ op_get_item, op_hardtanh, op_log, + op_max_pool2d, op_mm, op_mul, op_permute, diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py new file mode 100644 index 0000000000..0752d8242f --- /dev/null +++ b/backends/arm/operators/op_max_pool2d.py @@ -0,0 +1,77 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import cast, List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import get_quant_node_args + +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class MaxPool2dVisitor(NodeVisitor): + target = "aten.max_pool2d.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + + input_tensor = inputs[0] + kernel_size = inputs[1].special + stride = inputs[2].special + + try: + padding = [*inputs[3].special, *inputs[3].special] + except IndexError: + padding = [0, 0, 0, 0] + + accumulator_type = input_tensor.dtype + + if is_quant_node: + # Accumulator type always is int8 when input tensor is an integer type. + accumulator_type = ts.DType.INT8 + + # Initilize zero point to zero. + input_zp = 0 + output_zp = 0 + + if is_quant_node: + input_zp = get_quant_node_args( + cast(torch.fx.Node, node.all_input_nodes[0]) + ).zp + output_zp = get_quant_node_args(list(node.users)[0]).zp + + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + kernel=kernel_size, + stride=stride, + pad=padding, + input_zp=input_zp, + output_zp=output_zp, + accum_dtype=accumulator_type, + ) + + tosa_graph.addOperator( + TosaOp.Op().MAX_POOL2D, + [input_tensor.name], + [output.name], + attr, + ) diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index a1d7bfe296..4d52b7ddf1 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -147,6 +147,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool: # TODO: remove? torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.avg_pool2d.default, + torch.ops.aten.max_pool2d.default, torch.ops.aten.full.default, torch.ops.aten.flatten.using_ints, torch.ops.aten.dropout.default, diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index af44fa4474..b0e2a7f0bb 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -91,6 +91,17 @@ def pytest_sessionfinish(session, exitstatus): # ==== End of Pytest hooks ===== +# ==== Custom Pytest decorators ===== + + +def expectedFailureOnFVP(test_item): + if is_option_enabled("corstone300"): + test_item.__unittest_expecting_failure__ = True + return test_item + + +# ==== End of Custom Pytest decorators ===== + def load_libquantized_ops_aot_lib(): so_ext = { diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py new file mode 100644 index 0000000000..5c48afa3ce --- /dev/null +++ b/backends/arm/test/ops/test_max_pool.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + ArmQuantizer, + get_symmetric_quantization_config, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from executorch.backends.xnnpack.test.tester.tester import Quantize +from executorch.exir.backend.backend_details import CompileSpec +from parameterized import parameterized + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +test_data_suite = [ + # (test_name, test_data, [kernel_size, stride, padding]) + ("zeros", torch.zeros(1, 1, 4, 8), [2, 2, 1]), + ("ones", torch.ones(1, 16, 50, 32), [4, 2, 0]), + ("rand", torch.rand(1, 16, 52, 16), [4, 3, 0]), +] + +test_data_suite_mult_batches = [ + ("randn", torch.randn(5, 16, 50, 32), [4, 2, 0]), +] + + +class TestMaxPool2d(unittest.TestCase): + """Tests MaxPool2d.""" + + class MaxPool2d(torch.nn.Module): + def __init__( + self, + kernel_size: int | Tuple[int, int], + stride: int | Tuple[int, int], + padding: int | Tuple[int, int], + ): + super().__init__() + self.max_pool_2d = torch.nn.MaxPool2d( + kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x): + return self.max_pool_2d(x) + + def _test_maxpool2d_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + ) + .export() + .check(["torch.ops.aten.max_pool2d.default"]) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"]) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + def _test_maxpool2d_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.tensor] + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.max_pool2d.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"]) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_maxpool2d_tosa_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.tensor], + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.max_pool2d.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_MI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_BI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_BI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_u55_BI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u55_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-300" + ) + + @parameterized.expand(test_data_suite) + def test_maxpool2d_tosa_u85_BI( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u85_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-320" + ) + + @parameterized.expand(test_data_suite_mult_batches) + def test_maxpool2d_tosa_MI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_MI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite_mult_batches) + def test_maxpool2d_tosa_BI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + self._test_maxpool2d_tosa_BI_pipeline( + self.MaxPool2d(*model_params), (test_data,) + ) + + @parameterized.expand(test_data_suite_mult_batches) + @common.expectedFailureOnFVP # TODO: MLETORCH-433 + def test_maxpool2d_tosa_u55_BI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u55_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-300" + ) + + @parameterized.expand(test_data_suite_mult_batches) + @common.expectedFailureOnFVP # TODO: MLETORCH-433 + def test_maxpool2d_tosa_u85_BI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u85_compile_spec(permute_memory_to_nhwc=True), + (test_data,), + ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=(test_data,), target_board="corstone-320" + ) From 4bbe9945b7c221f7b687dbb6754ce4e650c93c05 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Thu, 7 Nov 2024 09:02:57 +0100 Subject: [PATCH 24/59] Run tosa_reference_model using python binding (#6658) This change makes it uneccessary to dump intermediates by default for running the reference_model --- backends/arm/arm_backend.py | 9 +-- backends/arm/test/common.py | 12 ++-- backends/arm/test/misc/test_debug_feats.py | 5 +- backends/arm/test/ops/test_cat.py | 2 +- backends/arm/test/ops/test_select.py | 4 +- backends/arm/test/runner_utils.py | 81 ++++++++++++++++++---- backends/arm/test/tester/arm_tester.py | 15 ++-- examples/arm/setup.sh | 27 ++------ 8 files changed, 96 insertions(+), 59 deletions(-) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 28af583106..db3b368115 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -13,7 +13,7 @@ import logging import os -from typing import final, List, Optional +from typing import cast, final, List, Optional import serializer.tosa_serializer as ts from executorch.backends.arm.arm_vela import vela_compile @@ -31,6 +31,7 @@ from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram +from torch.fx import Node # TOSA backend debug functionality logger = logging.getLogger(__name__) @@ -225,6 +226,7 @@ def preprocess( # noqa: C901 node_visitors = get_node_visitors(edge_program) for node in graph_module.graph.nodes: + node = cast(Node, node) if node.op == "call_function": process_call_function(node, tosa_graph, node_visitors) elif node.op == "placeholder": @@ -236,9 +238,6 @@ def preprocess( # noqa: C901 # any checking of compatibility. dbg_fail(node, tosa_graph, artifact_path) - # TODO: It would be awesome if this dump could somehow be done on top level and not here. - # Problem is that the desc.json has to be created on the tosa_graph object, which we can't - # access from top level. if artifact_path: tag = _get_first_delegation_tag(graph_module) dbg_tosa_dump( @@ -259,6 +258,4 @@ def preprocess( # noqa: C901 else: raise RuntimeError(f"Unknown format {output_format}") - # Continueing from above. Can I put tosa_graph into this function? - # debug_handle_map = ... return PreprocessResult(processed_bytes=binary) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index b0e2a7f0bb..1a155c0323 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -192,19 +192,15 @@ def get_tosa_compile_spec_unbuilt( the compile spec before calling .build() to finalize it. """ if not custom_path: - intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp( - prefix="arm_tosa_" - ) - else: - intermediate_path = custom_path + custom_path = maybe_get_tosa_collate_path() - if not os.path.exists(intermediate_path): - os.makedirs(intermediate_path, exist_ok=True) + if custom_path is not None and not os.path.exists(custom_path): + os.makedirs(custom_path, exist_ok=True) compile_spec_builder = ( ArmCompileSpecBuilder() .tosa_compile_spec() .set_permute_memory_format(permute_memory_to_nhwc) - .dump_intermediate_artifacts_to(intermediate_path) + .dump_intermediate_artifacts_to(custom_path) ) return compile_spec_builder diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 7d9a18a80e..1aa3e82c76 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -107,7 +107,10 @@ def test_numerical_diff_prints(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + permute_memory_to_nhwc=True, + custom_path=tempfile.mkdtemp("diff_print_test"), + ), ) .export() .to_edge() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index 9723ba0f0c..b0a38ce198 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int): def test_cat_4d_tosa_MI(self): square = torch.ones((2, 2, 2, 2)) for dim in range(-3, 3): - test_data = ((square, square), dim) + test_data = ((square, square.clone()), dim) self._test_cat_tosa_MI_pipeline(self.Cat(), test_data) @parameterized.expand(Cat.test_parameters) diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index fdb2fa1463..6a47c2e66b 100644 --- a/backends/arm/test/ops/test_select.py +++ b/backends/arm/test/ops/test_select.py @@ -93,8 +93,6 @@ def _test_select_tosa_BI_pipeline( .check(["torch.ops.quantized_decomposed"]) .to_edge() .partition() - .dump_artifact() - .dump_operator_distribution() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data) @@ -162,12 +160,14 @@ def test_select_int_tosa_MI(self, test_data: test_data_t): ) @parameterized.expand(test_data_suite) + @unittest.skip def test_select_copy_tosa_BI(self, test_data: test_data_t): self._test_select_tosa_BI_pipeline( self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int" ) @parameterized.expand(test_data_suite) + @unittest.skip def test_select_int_tosa_BI(self, test_data: test_data_t): self._test_select_tosa_BI_pipeline( self.SelectInt(), test_data, export_target="torch.ops.aten.select.int" diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index d2ee113a5d..f3c90eda83 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -17,11 +17,14 @@ import numpy as np import torch +import tosa_reference_model + from torch.export import ExportedProgram from torch.fx.node import Node +from tosa import TosaGraph logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) +logger.setLevel(logging.CRITICAL) class QuantizationParams: @@ -167,7 +170,7 @@ def __init__( ): self.intermediate_path = intermediate_path self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model" - assert os.path.exists( + assert self.intermediate_path is None or os.path.exists( self.intermediate_path ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}" @@ -323,7 +326,46 @@ def run_corstone( tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32) output_shape = self.output_node.args[0][0].meta["val"].shape tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape) - return [tosa_ref_output] + return tosa_ref_output + + def run_tosa_graph( + self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor] + ) -> torch.Tensor: + """Runs the TOSA reference model with inputs and returns the result.""" + data_np = [ + prep_data_for_save( + input, self.is_quantized, self.input_names[i], self.qp_input[i] + ) + for i, input in enumerate(inputs) + ] + # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training. + tosa_profile = 0 if self.is_quantized else 1 + debug_mode = "ALL" if logger.level <= logging.DEBUG else None + outputs, status = tosa_reference_model.run( + graph, + data_np, + verbosity=_tosa_refmodel_loglevel(logger.level), + tosa_profile=tosa_profile, + initialize_variable_tensor_from_numpy=1, # True + debug_mode=debug_mode, + ) + + assert ( + status == tosa_reference_model.GraphStatus.TOSA_VALID + ), "Non-valid TOSA given to reference model." + + outputs_torch = [] + for output in outputs: + output = output.astype(np.float32) + if self.is_quantized: + # Need to dequant back to FP32 for comparison with torch output + quant_param = self.qp_output + assert ( + quant_param is not None + ), "There are no quantization parameters, check output parameters" + output = (output - quant_param.zp) * quant_param.scale + outputs_torch.append(torch.from_numpy(output)) + return tuple(outputs_torch) def run_tosa_ref_model( self, @@ -408,21 +450,13 @@ def run_tosa_ref_model( assert ( shutil.which(self.tosa_ref_model_path) is not None ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}" - loglevel_map = { - logging.INFO: "INFO", - logging.CRITICAL: "LOW", - logging.ERROR: "LOW", - logging.WARNING: "MED", - logging.DEBUG: "HIGH", - logging.NOTSET: "MED", - } - clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0) + cmd_ref_model = [ self.tosa_ref_model_path, "--test_desc", desc_file_path, "-l", - loglevel_map[clamped_logging_level], + _tosa_refmodel_loglevel(logger.level), ] _run_cmd(cmd_ref_model) @@ -458,7 +492,10 @@ def run_tosa_ref_model( def prep_data_for_save( - data, is_quantized: bool, input_name: str, quant_param: QuantizationParams + data: torch.Tensor, + is_quantized: bool, + input_name: str, + quant_param: QuantizationParams, ): data_np = np.array(data.detach(), order="C").astype( f"{data.dtype}".replace("torch.", "") @@ -602,3 +639,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: pass return json_out + + +def _tosa_refmodel_loglevel(loglevel: int) -> str: + """Converts a logging loglevel to tosa_reference_model logginglevel, + returned as string. + """ + loglevel_map = { + logging.INFO: "INFO", + logging.CRITICAL: "LOW", + logging.ERROR: "LOW", + logging.WARNING: "MED", + logging.DEBUG: "HIGH", + logging.NOTSET: "MED", + } + clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0) + return loglevel_map[clamped_logging_level] diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 096bc2b22f..834e177b7d 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -39,7 +39,7 @@ from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, EdgeProgramManager from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule @@ -120,10 +120,15 @@ def __init__( super().__init__(dynamic_shapes) self.tosa_test_util = tosa_test_util + def run(self, artifact: EdgeProgramManager, inputs=None): + self.executorch_program = artifact.to_executorch(self.config) + if module := getattr( + artifact.exported_program().graph_module, "lowered_module_0", None + ): + self.buffer = module.processed_bytes + def run_artifact(self, inputs): - tosa_output = self.tosa_test_util.run_tosa_ref_model( - inputs=inputs, - ) + tosa_output = self.tosa_test_util.run_tosa_graph(self.buffer, inputs) return tosa_output @@ -316,7 +321,7 @@ def run_method_and_compare_outputs( logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}") reference_output = reference_stage.run_artifact(reference_input) - test_output = tuple(test_stage.run_artifact(test_input)) + test_output = test_stage.run_artifact(test_input) if ( is_nhwc and test_stage == self.stages[self.stage_name(tester.ToExecutorch)] diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 583237729d..43f7d48b83 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -88,7 +88,7 @@ ethos_u_base_rev="24.08" # tosa reference model tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model" -tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5" +tosa_reference_model_rev="ef31e7222e99cb1c24b2aff9fc52b2d609612283" ######## ### Mandatory user args @@ -227,30 +227,13 @@ function setup_tosa_reference_model() { cd reference_model git checkout ${tosa_reference_model_rev} git submodule update --init --recursive - cd .. - fi - cd reference_model - mkdir -p build - cd build - cmake .. - - # make use of half the cores for building - if [[ "${OS}" == "Linux" ]]; then - n=$(( $(nproc) / 2 )) - elif [[ "${OS}" == "Darwin" ]]; then - n=$(( $(sysctl -n hw.logicalcpu) / 2 )) - else - n=1 fi - if [[ "$n" -lt 1 ]]; then - n=1 - fi + echo "pip installing reference_model..." + repo_dir="${root_dir}/reference_model" + cd $repo_dir + pip install . - make -j"${n}" - cd reference_model - tosa_bin_path=`pwd` - echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}" } function setup_vela() { From 38346fdd0701638e6a9b3f4be662258a68d09b01 Mon Sep 17 00:00:00 2001 From: cad-audio <86048415+cad-audio@users.noreply.github.com> Date: Thu, 7 Nov 2024 09:03:49 -0800 Subject: [PATCH 25/59] Added HiFi optimized mean and where ops. (#6483) Adding mean and where ops optimized on HiFi Co-authored-by: dijopaul --- backends/cadence/aot/functions_hifi.yaml | 7 +- backends/cadence/hifi/kernels/CMakeLists.txt | 2 + backends/cadence/hifi/kernels/kernels.h | 28 + .../cadence/hifi/operators/CMakeLists.txt | 12 +- backends/cadence/hifi/operators/op_mean.cpp | 170 ++++ backends/cadence/hifi/operators/op_where.cpp | 176 ++++ .../nnlib/xa_nn_elm_where_f32xf32_f32.c | 838 ++++++++++++++++++ .../third-party/nnlib/xa_nn_reduce_32_32.c | 647 ++++++++++++++ 8 files changed, 1870 insertions(+), 10 deletions(-) create mode 100644 backends/cadence/hifi/operators/op_mean.cpp create mode 100644 backends/cadence/hifi/operators/op_where.cpp create mode 100644 backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c create mode 100644 backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 84c07be78c..52390e1918 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -62,6 +62,11 @@ - arg_meta: null kernel_name: torch::executor::full_out +- op: mean.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::mean_dim_out + - op: mul.out kernels: - arg_meta: null @@ -105,7 +110,7 @@ - op: where.self_out kernels: - arg_meta: null - kernel_name: torch::executor::where_out + kernel_name: cadence::impl::HiFi::where_out # custom ops - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt b/backends/cadence/hifi/kernels/CMakeLists.txt index 8fee7e8536..9321cc544e 100644 --- a/backends/cadence/hifi/kernels/CMakeLists.txt +++ b/backends/cadence/hifi/kernels/CMakeLists.txt @@ -13,6 +13,8 @@ add_library( ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c + ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c + ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c ) # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index 70d5e39fad..2c915661f8 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -55,6 +55,34 @@ extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32( const FLOAT32* __restrict__ p_inp2, const WORD32* const p_inp2_shape); +extern "C" WORD32 xa_nn_elm_where_f32xf32_f32( + FLOAT32* __restrict__ p_out, + const FLOAT32* __restrict__ p_inp1, + const FLOAT32* __restrict__ p_inp2, + const unsigned char* __restrict__ p_condition, + WORD32 num_elm); + +extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32( + FLOAT32* __restrict__ p_out, + const WORD32* const p_out_shape, + const FLOAT32* __restrict__ p_inp1, + const WORD32* const p_inp1_shape, + const FLOAT32* __restrict__ p_inp2, + const WORD32* const p_inp2_shape, + const unsigned char* __restrict__ p_condition, + const WORD32* const p_condition_shape); + +extern "C" WORD32 xa_nn_reduce_mean_4D_f32_f32( + FLOAT32* __restrict__ p_out, + const WORD32* const p_out_shape, + const FLOAT32* __restrict__ p_inp, + const WORD32* const p_inp_shape, + const WORD32* __restrict__ p_axis, + WORD32 num_out_dims, + WORD32 num_inp_dims, + WORD32 num_axis_dims, + void* __restrict__ p_scratch_in); + namespace cadence { namespace impl { namespace HiFi { diff --git a/backends/cadence/hifi/operators/CMakeLists.txt b/backends/cadence/hifi/operators/CMakeLists.txt index cbbb279e5d..dbe5867550 100644 --- a/backends/cadence/hifi/operators/CMakeLists.txt +++ b/backends/cadence/hifi/operators/CMakeLists.txt @@ -22,19 +22,12 @@ endif() set(_aten_ops__srcs "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp" + "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" + "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" @@ -57,6 +50,7 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" ) add_library(aten_ops_cadence ${_aten_ops__srcs}) target_link_libraries(aten_ops_cadence PUBLIC executorch) diff --git a/backends/cadence/hifi/operators/op_mean.cpp b/backends/cadence/hifi/operators/op_mean.cpp new file mode 100644 index 0000000000..478e10da71 --- /dev/null +++ b/backends/cadence/hifi/operators/op_mean.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::aten::RuntimeContext; +using executorch::runtime::ArrayRef; +using torch::executor::Error; +using torch::executor::optional; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +int prepare_data( + const Tensor& in, + Tensor& out, + optional> dim_list, + int* inp_shape, + int* out_shape, + int* p_axis, + int num_inp_dims, + int num_out_dims) { + for (int i = 0; i < num_inp_dims; i++) { + inp_shape[i] = in.size(i); + } + + for (int i = 0; i < num_out_dims; i++) { + out_shape[i] = out.size(i); + } + + int num_axis_dims = 0; + for (const auto& d : dim_list.value()) { + if (d < 0) { + p_axis[num_axis_dims] = num_inp_dims + d; + num_axis_dims++; + } else { + p_axis[num_axis_dims] = d; + num_axis_dims++; + } + } + + return num_axis_dims; +} + +Tensor& mean_dim_out( + RuntimeContext& ctx, + const Tensor& in, + optional> dim_list, + bool keepdim, + optional dtype, + Tensor& out) { + ET_KERNEL_CHECK( + ctx, + torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_reduction_out(in, dim_list, keepdim, out) == + Error::Ok, + InvalidArgument, + out); + + constexpr auto name = "mean.out"; + constexpr int kNnlibMaxDim = 4; + + bool optimized = 1; + + if (out.scalar_type() != ScalarType::Float) + optimized = 0; + + if (in.dim() > kNnlibMaxDim) + optimized = 0; + + if (optimized) { + float* __restrict__ p_out = out.mutable_data_ptr(); + const float* __restrict__ p_inp = + (const float* __restrict__)in.const_data_ptr(); + + int num_elm = in.numel(); + + int num_inp_dims = in.dim(); + int num_out_dims = out.dim(); + + int inp_shape[kNnlibMaxDim]; + int out_shape[kNnlibMaxDim]; + int p_axis[kNnlibMaxDim]; + + for (int i = 0; i < kNnlibMaxDim; i++) { + out_shape[i] = 1; + inp_shape[i] = 1; + p_axis[i] = 1; + } + + int num_axis_dims = prepare_data( + in, + out, + dim_list, + inp_shape, + out_shape, + p_axis, + num_inp_dims, + num_out_dims); + + if (num_axis_dims == num_inp_dims) { + num_out_dims = 1; + out_shape[0] = 1; + } + + int scratch_size = xa_nn_reduce_getsize_nhwc( + -3, inp_shape, num_inp_dims, p_axis, num_axis_dims, 1); + + void* __restrict__ p_scratch_in = (void* __restrict__)malloc(scratch_size); + + xa_nn_reduce_mean_4D_f32_f32( + p_out, + out_shape, + p_inp, + inp_shape, + p_axis, + num_out_dims, + num_inp_dims, + num_axis_dims, + p_scratch_in); + + return out; + } + + ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] { + ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] { + CTYPE_OUT* out_data = out.mutable_data_ptr(); + const size_t num = torch::executor::get_reduced_dim_product(in, dim_list); + + for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { + CTYPE_OUT sum = 0; + if (in.numel() > 0) { + sum = torch::executor::map_reduce_over_dim_list( + [](CTYPE_IN v) { return static_cast(v); }, + [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; }, + in, + dim_list, + out_ix); + } + out_data[out_ix] = sum / static_cast(num); + } + }); + }); + + return out; +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_where.cpp b/backends/cadence/hifi/operators/op_where.cpp new file mode 100644 index 0000000000..06bd0bc3c9 --- /dev/null +++ b/backends/cadence/hifi/operators/op_where.cpp @@ -0,0 +1,176 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::aten::RuntimeContext; +using torch::executor::Error; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +Tensor& where_out( + RuntimeContext& ctx, + const Tensor& cond, + const Tensor& a, + const Tensor& b, + Tensor& out) { + ScalarType cond_type = cond.scalar_type(); + ScalarType a_type = a.scalar_type(); + ScalarType b_type = b.scalar_type(); + ScalarType common_type = executorch::runtime::promoteTypes(a_type, b_type); + ScalarType out_type = out.scalar_type(); + + ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + + // Determine output size and resize for dynamic shapes + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, cond, out) == + Error::Ok, + InvalidArgument, + out); + + constexpr int kNnlibMaxDim = 4; /*fallback if broadcast and dim > 4 */ + constexpr auto name = "where.self_out"; + + ET_CHECK_MSG( + cond_type == ScalarType::Bool || cond_type == ScalarType::Byte, + "Unhandled dtype %s for where.self_out", + torch::executor::toString(cond_type)); + + int a_dim = a.dim(), b_dim = b.dim(), con_dim = cond.dim(), + out_dim = out.dim(); + bool optimized = 1; + /*find broadcast*/ + const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); + const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); + const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes()); + const bool broadcast = + (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted); + + int max_dim = a.dim() > b.dim() ? a.dim() : b.dim(); + max_dim = cond.dim() > max_dim ? cond.dim() : max_dim; + max_dim = out.dim() > max_dim ? out.dim() : max_dim; + + if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float)) + optimized = 0; + + if ((a_dim == 0) || (b_dim == 0) || (con_dim == 0)) + optimized = 0; + + if ((broadcast == 1) && (max_dim > kNnlibMaxDim)) + optimized = 0; + + if (optimized) { + const float* a_data = a.const_data_ptr(); + const float* b_data = b.const_data_ptr(); + float* out_data = out.mutable_data_ptr(); + const unsigned char* con = cond.const_data_ptr(); + + if (broadcast == 1) { + int out_shape[kNnlibMaxDim]; + int inp1_shape[kNnlibMaxDim]; + int inp2_shape[kNnlibMaxDim]; + int con_shape[kNnlibMaxDim]; + + for (int i = 0; i < kNnlibMaxDim; i++) { + con_shape[i] = 1; + out_shape[i] = 1; + inp1_shape[i] = 1; + inp2_shape[i] = 1; + } + + int off_o = kNnlibMaxDim - out.dim(); + int off_a = kNnlibMaxDim - a.dim(); + int off_b = kNnlibMaxDim - b.dim(); + int off_c = kNnlibMaxDim - cond.dim(); + + for (int i = 0; i < out.dim(); i++) + out_shape[i + off_o] = out.size(i); + for (int i = 0; i < a.dim(); i++) + inp1_shape[i + off_a] = a.size(i); + for (int i = 0; i < b.dim(); i++) + inp2_shape[i + off_b] = b.size(i); + for (int i = 0; i < cond.dim(); i++) + con_shape[i + off_c] = cond.size(i); + + if (con_shape[0] != out_shape[0] || con_shape[1] != out_shape[1] || + con_shape[2] != out_shape[2] || con_shape[3] != out_shape[3]) { + void* p_scratch = + malloc(out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]); + const unsigned char* p_brd_cond = (const unsigned char*)p_scratch; + xa_nn_broadcast_8_8( + (WORD8* __restrict__)p_brd_cond, + out_shape, + (const WORD8* __restrict__)con, + con_shape, + 4); + + for (int i = 0; i < 4; i++) { + con_shape[i] = out_shape[i]; + } + xa_nn_elm_where_broadcast_4D_f32xf32_f32( + out_data, + out_shape, + a_data, + inp1_shape, + b_data, + inp2_shape, + p_brd_cond, + con_shape); + free(p_scratch); + } else { + xa_nn_elm_where_broadcast_4D_f32xf32_f32( + out_data, + out_shape, + a_data, + inp1_shape, + b_data, + inp2_shape, + con, + con_shape); + } + } else { + xa_nn_elm_where_f32xf32_f32(out_data, a_data, b_data, con, out.numel()); + } + return out; + } + ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() { + using CTYPE_OUT = + typename torch::executor::promote_types::type; + torch::executor:: + apply_ternary_elementwise_fn( + [](const CTYPE_A val_a, + const CTYPE_B val_b, + const uint8_t val_c) { + CTYPE_OUT a_casted = static_cast(val_a); + CTYPE_OUT b_casted = static_cast(val_b); + return val_c ? a_casted : b_casted; + }, + a, + b, + cond, + out); + }); + }); + return out; +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c new file mode 100644 index 0000000000..6a7f6d0f77 --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c @@ -0,0 +1,838 @@ +/******************************************************************************* +* Copyright (c) 2018-2024 Cadence Design Systems, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to use this Software with Cadence processor cores only and +* not with any other processors and platforms, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +******************************************************************************/ +#include "xa_type_def.h" +#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nnlib_common_fpu.h" +#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nn_common.h" +#include "nnlib-hifi4/xa_nnlib/algo/common/include/xa_nnlib_err_chk.h" +#include "nnlib-hifi4/xa_nnlib/algo/kernels/basic/hifi4/xa_nn_basic_state.h" +#include "xa_nnlib_kernels_api.h" + + +#if !HAVE_VFPU +DISCARD_FUN_FOR_NONVOID_RETURN( + WORD32, xa_nn_elm_where_f32xf32_f32, + ( + FLOAT32 *p_out, + const FLOAT32 *p_inp1, + const FLOAT32 *p_inp2, + const unsigned char *__restrict__ condition, + WORD32 num_elm + ) + ) +#else +WORD32 xa_nn_elm_where_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp1, + const FLOAT32 * __restrict__ p_inp2, + const unsigned char *__restrict__ p_condition, + WORD32 num_elm) +{ + + /* NULL pointer checks */ + XA_NNLIB_ARG_CHK_PTR(p_out, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp1, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp2, -1); + /* Pointer alignment checks */ + XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp1, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp2, sizeof(FLOAT32), -1); + /* Basic Parameter checks */ + XA_NNLIB_ARG_CHK_COND((num_elm <= 0), -1); + + int i; + xtfloatx2 *inp1 = (xtfloatx2 *)p_inp1; + xtfloatx2 *inp2 = (xtfloatx2 *)p_inp2; + xtfloatx2 *out = (xtfloatx2 *)p_out; + unsigned char *condition = p_condition; + xtfloatx2 x1, x2, y; + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + if(((((unsigned)p_out)&7) == 0) && ((((unsigned)p_inp1)&7) == 0) && ((((unsigned)p_inp2)&7) == 0)) + { + for(i=0;i < num_elm>>1;i++) + { + XT_LSX2IP(x1, inp1, 2*sizeof(FLOAT32)); + XT_LSX2IP(x2, inp2, 2*sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SSX2IP( y, out, 2*sizeof(FLOAT32)); + } + } + else + { + ae_valign inp1_a, inp2_a, out_a; + + inp1_a = XT_LASX2PP(inp1); + inp2_a = XT_LASX2PP(inp2); + out_a = AE_ZALIGN64(); + /* Each iteration of loop is independent so safe to use concurrent pragma */ +#pragma concurrent + for(i=0;i < num_elm>>1;i++) + { + XT_LASX2IP(x1, inp1_a, inp1); + XT_LASX2IP(x2, inp2_a, inp2); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SASX2IP(y, out_a, out); + } + XT_SASX2POSFP(out_a, out); + } + // Remainder Loop + if (num_elm & 1) + { + xtfloat a1, a2, a; + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_LSIP(a1, (xtfloat *)inp1, 0); + XT_LSIP(a2, (xtfloat *)inp2, 0); + XT_MOVT_S(a, a1, s); + XT_MOVF_S(a, a2, s); + XT_SSI(a, (xtfloat *)out, 0); + } +} + +static void internal_elm_where_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp1, + const FLOAT32 * __restrict__ p_inp2, + const unsigned char * __restrict__ p_condition, + WORD32 num_elm, + xtbool sign_flag) +{ + int i; + xtfloatx2 * __restrict__ p_a = (xtfloatx2 *)p_inp1; + xtfloatx2 * __restrict__ p_b = (xtfloatx2 *)p_inp2; + xtfloatx2 *__restrict__ p_c = (xtfloatx2 *)p_out; + unsigned char *condition = p_condition; + + const int num_simd2_ops = num_elm >> 1; + const int num_scalar_ops = num_elm & 1; + + xtfloat a0_7, out; + xtfloatx2 x1, x2, y; + x2 = XT_LSI((xtfloat *)p_b, 0); + + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + /* For out = condition ? inp2 :inp1 */ + if(sign_flag){ + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(i=0; i> 1; + const int num_scalar_ops = num_elm & 1; + + xtfloat a0_7, out; + xtfloatx2 x1, x2, y; + x2 = XT_LSI((xtfloat *)p_b, 0); + x1 = XT_LSI((xtfloat *)p_a, 0); + + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + if((((unsigned)p_c)&7) == 0) + { + for(i=0; i> 1; + num_scalar_ops = in_lc & 1; + } + else + { + num_simd2_ops = (in_lc >> 2) << 1; + num_scalar_ops = in_lc & 3; + } + + xtfloatx2 x1, x2, y; + xtfloat a0, b0, c0; + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + /* For out = condition ? inp2 :inp1 */ + if(sign_flag){ + for(i = 0; i < out_lc; i++) + { + p_a = (xtfloatx2 *)&p_inp1[i * in_lc]; + p_b = (xtfloatx2 *)p_inp2; + p_c = (xtfloatx2 *)&p_out[i * in_lc]; + condition = &p_condition[i * in_lc]; + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x2, con); + XT_MOVF_SX2 (y, x1, con); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { + ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); + vinp1 = XT_LASX2PP(p_a); + vinp2 = XT_LASX2PP(p_b); + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x2, con); + XT_MOVF_SX2 (y, x1, con); + XT_SASX2IP(y, out_a, p_c); + } + XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); + } + if(num_scalar_ops !=0) + { + XT_LSIP(a0, (xtfloat *)p_a, 0); + XT_LSIP(b0, (xtfloat *)p_b, 0); + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_MOVT_S(c0, b0, s); + XT_MOVF_S(c0, a0, s); + XT_SSI(c0, (xtfloat *)p_c, 0); + } + } + } + /* For out = condition ? inp1 :inp2 */ + else + { + for(i = 0; i < out_lc; i++) + { + p_a = (xtfloatx2 *)&p_inp1[i * in_lc]; + p_b = (xtfloatx2 *)p_inp2; + p_c = (xtfloatx2 *)&p_out[i * in_lc]; + condition = &p_condition[i * in_lc]; + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { + ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); + vinp1 = XT_LASX2PP(p_a); + vinp2 = XT_LASX2PP(p_b); + + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SASX2IP(y, out_a, p_c); + } + XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); + } + if(num_scalar_ops !=0) + { + XT_LSIP(a0, (xtfloat *)p_a, 0); + XT_LSIP(b0, (xtfloat *)p_b, 0); + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_MOVT_S(c0, a0, s); + XT_MOVF_S(c0, b0, s); + XT_SSI(c0, (xtfloat *)p_c, 0); + } + } + } +} + +static void internal_elm_where_broadcast_both_2D_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp1, + const FLOAT32 * __restrict__ p_inp2, + const unsigned char * __restrict__ p_condition, + WORD32 out_lc, + WORD32 in_lc) +{ + int i, j; + + xtfloatx2 * __restrict__ p_a = (xtfloatx2 *)p_inp1; + xtfloatx2 * __restrict__ p_b = (xtfloatx2 *)p_inp2; + xtfloatx2 *__restrict__ p_c = (xtfloatx2 *)p_out; + unsigned char *condition = p_condition; + + int num_simd2_ops; + int num_scalar_ops; + + if(out_lc) + { + num_simd2_ops = in_lc >> 1; + num_scalar_ops = in_lc & 1; + } + else + { + num_simd2_ops = (in_lc >> 2) << 1; + num_scalar_ops = in_lc & 3; + } + + xtfloatx2 x1, x2, y; + xtfloat a0, b0, c0; + unsigned char con1, con2; + xtbool2 con = int32_rtor_xtbool2(0x00000003); + + for(i = 0; i < out_lc; i++) + { + p_a = (xtfloatx2 *)p_inp1; + p_b = (xtfloatx2 *)p_inp2; + p_c = (xtfloatx2 *)&p_out[i * in_lc]; + condition = &p_condition[i * in_lc]; + if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0)) + { + for(j = 0; j < num_simd2_ops; j++) + { + XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32)); + XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32)); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32)); + } + } + else + { + ae_valign vinp1, vinp2, out_a = AE_ZALIGN64(); + vinp1 = XT_LASX2PP(p_a); + vinp2 = XT_LASX2PP(p_b); + + for(j = 0; j < num_simd2_ops; j++) + { + XT_LASX2IP(x1, vinp1, p_a); + XT_LASX2IP(x2, vinp2, p_b); + con1 = XT_L8UI(condition, 0); + condition++; + con2 = XT_L8UI(condition, 0); + condition++; + con = AE_MOVBA1X2(con1, con2); + XT_MOVT_SX2 (y, x1, con); + XT_MOVF_SX2 (y, x2, con); + XT_SASX2IP(y, out_a, p_c); + } + XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c); + } + if(num_scalar_ops !=0) + { + XT_LSIP(a0, (xtfloat *)p_a, 0); + XT_LSIP(b0, (xtfloat *)p_b, 0); + con1 = XT_L8UI(condition, 0); + xtbool s = AE_MOVBA(con1); + XT_MOVT_S(c0, a0, s); + XT_MOVF_S(c0, b0, s); + XT_SSI(c0, (xtfloat *)p_c, 0); + } + } +} + +WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out, + const WORD32 *const p_out_shape, + const FLOAT32 * __restrict__ p_inp1, + const WORD32 *const p_inp1_shape, + const FLOAT32 * __restrict__ p_inp2, + const WORD32 *const p_inp2_shape, + const unsigned char *__restrict__ p_condition, + const WORD32 *const p_condition_shape + ) +{ + /* NULL pointer checks */ + XA_NNLIB_ARG_CHK_PTR(p_out, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp1, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp2, -1); + XA_NNLIB_ARG_CHK_PTR(p_condition, -1); + XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp1_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp2_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_condition_shape, -1); + /* Pointer alignment checks */ + XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp1, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp2, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_condition, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp1_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp2_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_condition_shape, sizeof(WORD32), -1); + + /* Check shapes */ + int i; + xtbool sign_flag; + for(i = 0; i < 4; i++) + { + if((p_inp1_shape[i] != p_inp2_shape[i]) && ((p_inp1_shape[i] != 1) && (p_inp2_shape[i] != 1))) + { + return -1; + } + } + WORD32 inp1_strides[4], inp2_strides[4]; + inp1_strides[3] = 1; + inp2_strides[3] = 1; + for(i = 2; i >= 0; i--) + { + ae_int32x2 d_str, d_shape; + d_str = AE_MOVDA32X2(inp1_strides[i + 1], inp2_strides[i + 1]); + d_shape = AE_MOVDA32X2(p_inp1_shape[i + 1], p_inp2_shape[i + 1]); + d_str = AE_MULP32X2(d_str, d_shape); + inp1_strides[i] = AE_MOVAD32_H(d_str); + inp2_strides[i] = AE_MOVAD32_L(d_str); + } + + int need_broadcast = 0; + int inp1_const = 1, inp2_const = 1; + for(i = 0; i < 4; i++) + { + if(p_inp1_shape[i] == 1) + { + inp1_strides[i] = 0; + need_broadcast = 1; + } + else + { + inp1_const &= 0; + } + if(p_inp2_shape[i] == 1) + { + inp2_strides[i] = 0; + need_broadcast = 1; + } + else + { + inp2_const &= 0; + } + } + + int itr0, itr1, itr2; + FLOAT32 *p_out_tmp = p_out; + const unsigned char *__restrict p_condition_temp = p_condition; + const FLOAT32 *__restrict__ p_inp1_tmp = p_inp1; + const FLOAT32 *__restrict__ p_inp2_tmp = p_inp2; + + if(need_broadcast == 0) + { + sign_flag = 0; + internal_elm_where_broadcast_2D_f32xf32_f32( + p_out, + p_inp1, + p_inp2, + p_condition, + 1, + p_out_shape[0] * inp1_strides[0], + sign_flag); + } + else if((inp1_strides[3] == 1)&& (inp2_strides[3] == 1)) + { + WORD32 in_lc, out_lc; + sign_flag = 0; + in_lc = p_out_shape[2] * p_out_shape[3]; + out_lc = 1; + if((inp1_strides[2] == 0) && (inp2_strides[2] == 0)) + { + in_lc = p_out_shape[3]; + out_lc = p_out_shape[2]; + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + internal_elm_where_broadcast_both_2D_f32xf32_f32( + p_out_tmp, + p_inp1_tmp0, + p_inp2_tmp0, + p_condition_temp, + out_lc, + in_lc); + p_out_tmp += in_lc * out_lc; + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + p_condition_temp += in_lc * out_lc; + } + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + else + { + if(inp1_strides[2] == 0) + { + const FLOAT32 *tmp; + tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp; + sign_flag = 1; + int tmp_strides[2]; + tmp_strides[0] = inp1_strides[0]; + tmp_strides[1] = inp1_strides[1]; + + inp1_strides[0] = inp2_strides[0]; + inp1_strides[1] = inp2_strides[1]; + + inp2_strides[0] = tmp_strides[0]; + inp2_strides[1] = tmp_strides[1]; + in_lc = p_out_shape[3]; + out_lc = p_out_shape[2]; + } + else if(inp2_strides[2] == 0) + { + in_lc = p_out_shape[3]; + out_lc = p_out_shape[2]; + } + + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + internal_elm_where_broadcast_2D_f32xf32_f32( + p_out_tmp, + p_inp1_tmp0, + p_inp2_tmp0, + p_condition_temp, + out_lc, + in_lc, + sign_flag); + p_out_tmp += in_lc * out_lc; + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + p_condition_temp += in_lc * out_lc; + } + + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + } + else if(inp1_const == 1 || inp2_const == 1) + { + if((inp1_const == 1)&&(inp2_const == 1)) + { + internal_elm_where_broadcast_both_f32xf32_f32( + p_out_tmp, + p_inp1_tmp, + p_inp2_tmp, + p_condition_temp, + p_out_shape[0] * p_out_shape[1] * p_out_shape[2] * p_out_shape[3]); + } + else + { + sign_flag = 0; + if(inp1_strides[3] == 0) + { + sign_flag = 1; + const FLOAT32 *tmp; + tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp; + } + internal_elm_where_broadcast_f32xf32_f32( + p_out_tmp, + p_inp1_tmp, + p_inp2_tmp, + p_condition_temp, + p_out_shape[0] * p_out_shape[1] * p_out_shape[2] * p_out_shape[3], + sign_flag); + } + } + else + { + sign_flag = 0; + if((inp1_strides[3] == 0) && (inp2_strides[3] == 0)) + { + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + const FLOAT32 *__restrict__ p_inp1_tmp1 = p_inp1_tmp0; + const FLOAT32 *__restrict__ p_inp2_tmp1 = p_inp2_tmp0; + for(itr2 = 0; itr2 < p_out_shape[2]; itr2++) + { + { + internal_elm_where_broadcast_both_f32xf32_f32( + p_out_tmp, + p_inp1_tmp1, + p_inp2_tmp1, + p_condition_temp, + p_out_shape[3]); + } + p_out_tmp += p_out_shape[3]; + p_inp1_tmp1 += inp1_strides[2]; + p_inp2_tmp1 += inp2_strides[2]; + p_condition_temp += p_out_shape[3]; + } + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + } + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + else + { + if(inp1_strides[3] == 0) + { + const FLOAT32 *tmp; + tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp; + sign_flag = 1; + int tmp_strides[3]; + tmp_strides[0] = inp1_strides[0]; + tmp_strides[1] = inp1_strides[1]; + tmp_strides[2] = inp1_strides[2]; + + inp1_strides[0] = inp2_strides[0]; + inp1_strides[1] = inp2_strides[1]; + inp1_strides[2] = inp2_strides[2]; + + inp2_strides[0] = tmp_strides[0]; + inp2_strides[1] = tmp_strides[1]; + inp2_strides[2] = tmp_strides[2]; + } + for(itr0 = 0; itr0 < p_out_shape[0]; itr0++) + { + const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp; + const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp; + for(itr1 = 0; itr1 < p_out_shape[1]; itr1++) + { + const FLOAT32 *__restrict__ p_inp1_tmp1 = p_inp1_tmp0; + const FLOAT32 *__restrict__ p_inp2_tmp1 = p_inp2_tmp0; + for(itr2 = 0; itr2 < p_out_shape[2]; itr2++) + { + { + internal_elm_where_broadcast_f32xf32_f32( + p_out_tmp, + p_inp1_tmp1, + p_inp2_tmp1, + p_condition_temp, + p_out_shape[3], + sign_flag); + } + p_out_tmp += p_out_shape[3]; + p_inp1_tmp1 += inp1_strides[2]; + p_inp2_tmp1 += inp2_strides[2]; + p_condition_temp += p_out_shape[3]; + } + p_inp1_tmp0 += inp1_strides[1]; + p_inp2_tmp0 += inp2_strides[1]; + } + p_inp1_tmp += inp1_strides[0]; + p_inp2_tmp += inp2_strides[0]; + } + } + } + return 0; +} + +#endif \ No newline at end of file diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c new file mode 100644 index 0000000000..5978a92d26 --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c @@ -0,0 +1,647 @@ +#include "xa_nnlib_common.h" +#include +//#include "xa_nn_basic_state.h" +#include "xa_nnlib_common_macros.h" + +#define ALIGNMENT_8 8 + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x))+(bytes-1))&(~(bytes-1))) + +static void vecmean16_inpx3(const xtfloatx2 *p_src1, const xtfloat* p_src2, const xtfloat* p_src3, xtfloatx2 *p_dst, int N){ + int i = 0; + ae_valign align_src1, align_dst; + ae_valign align_src2, align_src3; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_src2); + align_src3 = AE_LA64_PP(p_src3); + align_dst = AE_ZALIGN64(); + + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1_h, j1_l, j2_h, j2_l; + + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + + XT_LASX2IP(j1_h, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j1_l, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j2_h, align_src3, (xtfloatx2 *)p_src3); + XT_LASX2IP(j2_l, align_src3, (xtfloatx2 *)p_src3); + + j1_h = XT_ADD_SX2(j1_h, j2_h); + j1_l = XT_ADD_SX2(j1_l, j2_l); + wout1 = XT_ADD_SX2(wout1, j1_h); + wout2 = XT_ADD_SX2(wout2, j1_l); + + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1, j2; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat)); + j1 = (xtfloat) *(p_src2 + i); + j2 = (xtfloat) *(p_src3 + i); + + j1 = XT_ADD_S(j1, j2); + wout1 = XT_ADD_S(wout1, j1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat)); + } +} + +static void vecmean16_inpx2(const xtfloatx2 *p_src1, const xtfloat* p_src2, xtfloatx2 *p_dst, int N){ + ae_valign align_src1, align_dst; + ae_valign align_src2; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_src2); + align_dst = AE_ZALIGN64(); + + int i = 0; + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1, j2; + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + + XT_LASX2IP(j1, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j2, align_src2, (xtfloatx2 *)p_src2); + + wout1 = XT_ADD_SX2(wout1, j1); + wout2 = XT_ADD_SX2(wout2, j2); + + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat)); + j1 = (xtfloat) *(p_src2 + i); + wout1 = XT_ADD_S(wout1, j1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat)); + } +} + +static void vecmean32_inpx3(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, const xtfloatx2* p_wsrc3, xtfloatx2 *p_dst, int N){ + ae_valign align_src1, align_src2, align_src3, align_dst; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_wsrc2); + align_src3 = AE_LA64_PP(p_wsrc3); + align_dst = AE_ZALIGN64(); + + int i = 0; + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1, j2, j3, j4; + xtfloatx2 wj1, wj2; + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + XT_LASX2IP(j1, align_src2, p_wsrc2); + XT_LASX2IP(j2, align_src3, p_wsrc3); + XT_LASX2IP(j3, align_src2, p_wsrc2); + XT_LASX2IP(j4, align_src3, p_wsrc3); + + wj1 = XT_ADD_SX2(j1, j2); + wj2 = XT_ADD_SX2(j3, j4); + wout1 = XT_ADD_SX2(wout1, wj1); + wout2 = XT_ADD_SX2(wout2, wj2); + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1, j2; + xtfloat wj1; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, 4); + XT_LSXP(j1, (xtfloat *)p_wsrc2, 4); + XT_LSXP(j2, (xtfloat *)p_wsrc3, 4); + wj1 = XT_ADD_S(j1, j2); + wout1 = XT_ADD_S(wout1, wj1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat)); + } +} + +static void vecmean32_inpx2(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, xtfloatx2 *p_dst, int N){ + ae_valign align_src1, align_src2, align_dst; + align_src1 = AE_LA64_PP(p_src1); + align_src2 = AE_LA64_PP(p_wsrc2); + align_dst = AE_ZALIGN64(); + + int i = 0; + for(i=0; i < (N >> 2); i++) + { + xtfloatx2 j1, j2; + xtfloatx2 wout1, wout2; + XT_LASX2IP(wout1, align_src1, p_src1); + XT_LASX2IP(wout2, align_src1, p_src1); + XT_LASX2IP(j1, align_src2, p_wsrc2); + XT_LASX2IP(j2, align_src2, p_wsrc2); + wout1 = XT_ADD_SX2(wout1, j1); + wout2 = XT_ADD_SX2(wout2, j2); + XT_SASX2IP(wout1, align_dst, p_dst); + XT_SASX2IP(wout2, align_dst, p_dst); + } + AE_SA64POS_FP(align_dst, p_dst); // finalize the stream + + //Remainder Loop + for(i=0; i < (N & 3); i++) + { + xtfloat j1; + xtfloat wout1; + XT_LSXP(wout1, (xtfloat *)p_src1, 4); + XT_LSXP(j1, (xtfloat *)p_wsrc2, 4); + wout1 = XT_ADD_S(wout1, j1); + XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(WORD32)); + } +} + +static inline void xa_nn_reduce_sum_4D_f32_f32(const FLOAT32 * __restrict__ p_inp + ,const WORD32 *const p_4D_inp_shape + ,const WORD32 * __restrict__ p_axis_data + ,WORD32 num_inp_dims + ,WORD32 num_axis_dims + ,pVOID p_scratch_in) +{ + xtfloat *p_in = (xtfloat *)(p_inp); + xtfloat *p_scratch = (xtfloat *)(p_scratch_in); + + int temp_inp_n = p_4D_inp_shape[0]; + int temp_inp_h = p_4D_inp_shape[1]; + int temp_inp_w = p_4D_inp_shape[2]; + int temp_inp_c = p_4D_inp_shape[3]; + + int itr_axis = 0, itr_n = 0, itr_h = 0, itr_w = 0, itr_c = 0; + xtfloat *p_src2, *p_src3; + xtfloatx2 *p_src1; + xtfloatx2 * p_dst; + ae_valign align_src2; + + int axis_dims_count = num_axis_dims; + if(axis_dims_count) + { + switch(p_axis_data[itr_axis]) + { + case 0: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n & ~(2 - 1)); itr_n += 2) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_src2 = p_in + itr_n * plane_size; + p_src3 = p_in + (itr_n + 1) * plane_size; + p_dst = (xtfloatx2 *)p_scratch; + vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, plane_size); + } + + if(temp_inp_n & 1) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_src2 = (p_in + itr_n * plane_size); + p_dst = (xtfloatx2 *)p_scratch; + vecmean16_inpx2(p_src1, p_src2, p_dst, plane_size); + } + temp_inp_n = 1; + }break; + case 1: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + for(itr_h=0; itr_h < (temp_inp_h & ~(2 - 1)); itr_h += 2) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size); + p_src3 = p_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, wc_plane_size); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + } + + if(temp_inp_h & 1) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean16_inpx2(p_src1, p_src2, p_dst, wc_plane_size); + } + } + temp_inp_h = 1; + }break; + case 2:{ + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hc_plane_size = temp_inp_h * temp_inp_c; + + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hc_plane_size) + itr_h * temp_inp_c))); + for(itr_w=0; itr_w < (temp_inp_w & ~(2 - 1)); itr_w += 2) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c); + p_src3 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, temp_inp_c); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c)); + } + + if(temp_inp_w & 1) + { + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean16_inpx2(p_src1, p_src2, p_dst, temp_inp_c); + } + } + } + temp_inp_w = 1; + }break; + case 3: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hw_plane_size = temp_inp_h * temp_inp_w; + int rem_c = (temp_inp_c & 7); + + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + for(itr_w=0; itr_w < (temp_inp_w); itr_w++) + { + p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w))); + p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w); + align_src2 = AE_LA64_PP(p_src2); + + for(itr_c=0; itr_c < (temp_inp_c >> 3); itr_c++) + { + xtfloatx2 j11, j12, j21, j22, i1; + i1 = XT_LSX((xtfloat *)p_src1, 0); + XT_LASX2IP(j11, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j12, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j21, align_src2, (xtfloatx2 *)p_src2); + XT_LASX2IP(j22, align_src2, (xtfloatx2 *)p_src2); + + j11 = XT_ADD_SX2(j11, j12); + j21 = XT_ADD_SX2(j21, j22); + + xtfloatx2 t1 = XT_SEL32_HH_SX2(j11, j11); + xtfloatx2 t2 = XT_SEL32_HH_SX2(j21, j21); + + j11 = XT_ADD_SX2(j11, t1); + j21 = XT_ADD_SX2(j21, t2); + + j11 = XT_ADD_SX2(j11, j21); + i1 = XT_ADD_SX2(i1, j11); + + XT_SSX(i1, (xtfloat *)p_dst, 0); + + p_src1 = p_dst; + } + //Remainder Loop + for(itr_c=0; itr_c < rem_c ; itr_c++) + { + xtfloat j1; + xtfloat i1; + i1 = XT_LSX((xtfloat *)p_src1, 0); + j1 = *p_src2++; + + i1 = XT_ADD_S(i1, j1); + XT_SSX(i1, (xtfloat *)p_dst, 0); + } + } + } + } + temp_inp_c = 1; + }break; + default: + break; + } + + axis_dims_count--; + itr_axis++; + } + + while(axis_dims_count) + { + ae_valign align_src; + xtfloat *p_scr_in = p_scratch; + xtfloatx2 *p_wsrc2, *p_wsrc3; + switch(p_axis_data[itr_axis]) + { + case 0: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + for(itr_n=1; itr_n < ((temp_inp_n -1) & ~(2 - 1)); itr_n += 2) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size); + p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n + 1) * plane_size); + p_dst = (xtfloatx2 *)p_scratch; + vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, plane_size); + } + + if((temp_inp_n - 1) & 1) + { + p_src1 = (xtfloatx2 *)p_scratch; + p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size); + p_dst = (xtfloatx2 *)p_scratch; + vecmean32_inpx2(p_src1, p_wsrc2, p_dst, plane_size); + } + temp_inp_n = 1; + }break; + case 1: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + p_src1 = (xtfloatx2 *)(p_scratch + + (itr_n * plane_size)); + for(itr_h = 1; itr_h < ((temp_inp_h - 1) & ~(2 - 1)); itr_h += 2) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size)); + p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, wc_plane_size); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + } + + if((temp_inp_h - 1) & 1) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size)); + vecmean32_inpx2(p_src1, p_wsrc2, p_dst, plane_size); + } + } + temp_inp_h = 1; + }break; + case 2:{ + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hc_plane_size = temp_inp_h * temp_inp_c; + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + p_src1 = (xtfloatx2 *)(p_scratch + ((itr_n * plane_size) + (itr_h * wc_plane_size))); + for(itr_w = 1; itr_w < ((temp_inp_w - 1) & ~(2 - 1)); itr_w += 2) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c)); + p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, temp_inp_c); + p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c)); + } + + if((temp_inp_w - 1) & 1) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c); + vecmean32_inpx2(p_src1, p_wsrc2, p_dst, temp_inp_c); + } + } + } + temp_inp_w = 1; + }break; + case 3: { + int plane_size = temp_inp_h * temp_inp_w * temp_inp_c; + int wc_plane_size = temp_inp_w * temp_inp_c; + int hw_plane_size = temp_inp_h * temp_inp_w; + int rem_c = ((temp_inp_c) & 3); + for(itr_n=0; itr_n < (temp_inp_n); itr_n++) + { + for(itr_h=0; itr_h < (temp_inp_h); itr_h++) + { + for(itr_w=0; itr_w < (temp_inp_w); itr_w++) + { + p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c)); + p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w); + align_src = AE_LA64_PP(p_wsrc2); + xtfloatx2 i1 = AE_MOVXTFLOATX2_FROMF32X2(AE_MOVDA32(0)); + for(itr_c = 0; itr_c < (temp_inp_c >> 2); itr_c++) + { + xtfloatx2 j1, j2; + XT_LASX2IP(j1, align_src, p_wsrc2); + XT_LASX2IP(j2, align_src, p_wsrc2); + + xtfloatx2 t1 = XT_SEL32_HH_SX2(j1, j1); + xtfloatx2 t2 = XT_SEL32_HH_SX2(j2, j2); + + j1 = XT_ADD_SX2(t1, j1); + j2 = XT_ADD_SX2(t2, j2); + + i1 = XT_ADD_SX2(i1, j1); + i1 = XT_ADD_SX2(i1, j2); + } + + //Remainder Loop + for(itr_c=0; itr_c < rem_c; itr_c++) + { + xtfloat j1; + XT_LSXP(j1, (xtfloat *)p_wsrc2, sizeof(xtfloat)); + i1 = XT_ADD_S(i1, j1); + } + XT_SSX(i1, (xtfloat *)p_dst, 0); + } + } + } + temp_inp_c = 1; + }break; + default: + break; + } + axis_dims_count--; + itr_axis++; + } +} + +WORD32 xa_nn_reduce_mean_4D_f32_f32( + FLOAT32 * __restrict__ p_out, + const WORD32 *const p_out_shape, + const FLOAT32 * __restrict__ p_inp, + const WORD32 *const p_inp_shape, + const WORD32 * __restrict__ p_axis, + WORD32 num_out_dims, + WORD32 num_inp_dims, + WORD32 num_axis_dims, + void * __restrict__ p_scratch_in) +{ + /* NULL pointer checks */ + XA_NNLIB_ARG_CHK_PTR(p_out, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp, -1); + XA_NNLIB_ARG_CHK_PTR(p_axis, -1); + XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1); + XA_NNLIB_ARG_CHK_PTR(p_inp_shape, -1); + + /* Invalid input checks */ + XA_NNLIB_ARG_CHK_COND(((num_inp_dims <= 0) || (num_inp_dims > 4)), -1); + XA_NNLIB_ARG_CHK_COND(((num_out_dims <= 0) || (num_out_dims > 4)), -1); + XA_NNLIB_ARG_CHK_COND(((num_axis_dims < 0) || (num_axis_dims > 4)), -1); + + int axis_itr = 0, inp_itr = 0, out_itr = 0; + int num_elm_in_axis = 1; + int current, past = -1; + for(axis_itr=0; axis_itr < num_axis_dims; axis_itr++) + { + current = p_axis[axis_itr]; + XA_NNLIB_ARG_CHK_COND(((current < 0) || (current > (num_inp_dims - 1))), -1); + XA_NNLIB_ARG_CHK_COND((p_inp_shape[current] > 1024), -1); + + /* Avoid calculation in case of repeated axis dims*/ + if(current != past) + { + num_elm_in_axis *= p_inp_shape[current]; + past = current; + } + } + + for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++) + { + XA_NNLIB_ARG_CHK_COND((p_inp_shape[inp_itr] <= 0), -1); + } + + int out_length = 1; + for(out_itr=0; out_itr < num_out_dims; out_itr++) + { + XA_NNLIB_ARG_CHK_COND((p_out_shape[out_itr] <= 0), -1); + out_length *= p_out_shape[out_itr]; + } + + /* Pointer alignment checks */ + XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(FLOAT32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_axis, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1); + XA_NNLIB_ARG_CHK_ALIGN(p_inp_shape, sizeof(WORD32), -1); + + FLOAT32 *p_in = (FLOAT32 *)(p_inp); + WORD32 *p_scratch = (WORD32 *)(ALIGN_PTR(p_scratch_in, ALIGNMENT_8)); + + // Changing order of axis data so that reduce max will be first computed + // across largest inp shape dim in axis. This is required to + // minimize the scratch usage. + int inp_length = 1, p_axis_data[4] = {0}, inp_shape_max; + if(num_axis_dims) + { + inp_shape_max = p_inp_shape[p_axis[0]]; + axis_itr = 1; + int max_axis_itr = 0; + int temp_p_axis_0 = p_axis[0]; + for(axis_itr = 0; axis_itr < num_axis_dims; axis_itr++) + { + p_axis_data[axis_itr] = p_axis[axis_itr]; + } + for(axis_itr = 1; axis_itr < num_axis_dims; axis_itr++) + { + if(p_inp_shape[p_axis[axis_itr]] > inp_shape_max) + { + inp_shape_max = p_inp_shape[p_axis[axis_itr]]; + max_axis_itr = axis_itr; + } + } + p_axis_data[0] = p_axis_data[max_axis_itr]; + p_axis_data[max_axis_itr] = temp_p_axis_0; + + inp_itr = 0; + for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++) + { + inp_length *= p_inp_shape[inp_itr]; + } + + memset(p_scratch, 0, ((inp_length / inp_shape_max) * sizeof(WORD32))); //TODO: Alternate approach for memset? + } + + // Promoting lesser dim tensors to 4D tensors. Also modifying axis + // data accordingly. + int p_4D_inp_shape[4] = {1, 1, 1, 1}; + int itr = num_inp_dims - 1; + int count = 3; + while(itr >= 0) + { + p_4D_inp_shape[count] = p_inp_shape[itr]; + itr--; + count--; + } + for(itr = 0; itr < num_axis_dims; itr++) + { + p_axis_data[itr] = p_axis_data[itr] + (4 - num_inp_dims); + } + ae_valign align_out = AE_ZALIGN64(); + + if(num_axis_dims) + { + if(num_elm_in_axis > 1) + { + xa_nn_reduce_sum_4D_f32_f32(p_in, + p_4D_inp_shape, + p_axis_data, + num_inp_dims, + num_axis_dims, + p_scratch); + itr = 0; + xtfloatx2 *p_src1 = (xtfloatx2 *)(p_scratch); + + float div = 1; + + for(int i = 0; i < num_axis_dims; i++) + { + div = div * (float)p_4D_inp_shape[p_axis_data[i]]; + } + + float mul = 1 / div; + + xtfloatx2 multiplier = XT_LSX((xtfloat *)&mul, 0); + + for(itr = 0; itr < (out_length >> 3); itr++) + { + xtfloatx2 temp1, temp2, temp3, temp4; + + temp2 = XT_LSX2X(p_src1, 8); + temp3 = XT_LSX2X(p_src1, 16); + temp4 = XT_LSX2X(p_src1, 24); + XT_LSX2XP(temp1, p_src1, 32); + + temp1 = XT_MUL_SX2(temp1, multiplier); + temp2 = XT_MUL_SX2(temp2, multiplier); + temp3 = XT_MUL_SX2(temp3, multiplier); + temp4 = XT_MUL_SX2(temp4, multiplier); + + XT_SASX2IP(temp1, align_out, (xtfloatx2 *)p_out); + XT_SASX2IP(temp2, align_out, (xtfloatx2 *)p_out); + XT_SASX2IP(temp3, align_out, (xtfloatx2 *)p_out); + XT_SASX2IP(temp4, align_out, (xtfloatx2 *)p_out); + } + AE_SA64POS_FP(align_out, p_out); + + for(itr = 0; itr < (out_length & 7); itr++) + { + xtfloat temp1; + XT_LSXP(temp1, (xtfloat *)p_src1, 4); + temp1 = XT_MUL_S(temp1, multiplier); + XT_SSXP(temp1, (xtfloat *)p_out, 4); + } + } + else + { + + memcpy(p_out, p_inp, inp_length * sizeof(FLOAT32)); + } + } + else + { + memcpy(p_out, p_inp, inp_length * sizeof(FLOAT32)); + } + + return 0; +} From abc8a5fabddb23ce66a9843a1a176b72589d2e7c Mon Sep 17 00:00:00 2001 From: David Lin Date: Thu, 7 Nov 2024 09:33:07 -0800 Subject: [PATCH 26/59] Revert changes to executor_runner (#6687) fix seg fault Co-authored-by: lind --- .../executor_runner/executor_runner.cpp | 44 +++---------------- examples/portable/executor_runner/targets.bzl | 1 - 2 files changed, 6 insertions(+), 39 deletions(-) diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index 35e58fec03..93c150c0b9 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -22,11 +22,7 @@ #include -#include -#include - #include -#include #include #include #include @@ -40,13 +36,8 @@ DEFINE_string( model_path, "model.pte", "Model serialized in flatbuffer format."); -DEFINE_bool( - is_fd_uri, - false, - "True if the model_path passed is a file descriptor with the prefix \"fd:///\"."); using executorch::extension::FileDataLoader; -using executorch::extension::FileDescriptorDataLoader; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::HierarchicalAllocator; @@ -58,33 +49,6 @@ using executorch::runtime::Program; using executorch::runtime::Result; using executorch::runtime::Span; -static Result getProgram( - const bool is_fd_uri, - const char* model_path) { - // Create a loader to get the data of the program file. This demonstrates both - // FileDataLoader and FileDescriptorDataLoader. There are other DataLoaders - // that use mmap() or point to data that's already in memory, and users can - // create their own DataLoaders to load from arbitrary sources. - if (!is_fd_uri) { - Result loader = FileDataLoader::from(model_path); - - ET_CHECK_MSG( - loader.ok(), - "FileDataLoader::from() failed: 0x%" PRIx32, - (uint32_t)loader.error()); - return Program::load(&loader.get()); - } else { - Result loader = - FileDescriptorDataLoader::fromFileDescriptorUri(model_path); - - ET_CHECK_MSG( - loader.ok(), - "FileDescriptorDataLoader::fromFileDescriptorUri() failed: 0x%" PRIx32, - (uint32_t)loader.error()); - return Program::load(&loader.get()); - } -} - int main(int argc, char** argv) { executorch::runtime::runtime_init(); @@ -102,11 +66,15 @@ int main(int argc, char** argv) { // DataLoaders that use mmap() or point to data that's already in memory, and // users can create their own DataLoaders to load from arbitrary sources. const char* model_path = FLAGS_model_path.c_str(); - const bool is_fd_uri = FLAGS_is_fd_uri; + Result loader = FileDataLoader::from(model_path); + ET_CHECK_MSG( + loader.ok(), + "FileDataLoader::from() failed: 0x%" PRIx32, + (uint32_t)loader.error()); // Parse the program file. This is immutable, and can also be reused between // multiple execution invocations across multiple threads. - Result program = getProgram(is_fd_uri, model_path); + Result program = Program::load(&loader.get()); if (!program.ok()) { ET_LOG(Error, "Failed to parse model file %s", model_path); return 1; diff --git a/examples/portable/executor_runner/targets.bzl b/examples/portable/executor_runner/targets.bzl index 83c63d3a41..9cddaa4ed7 100644 --- a/examples/portable/executor_runner/targets.bzl +++ b/examples/portable/executor_runner/targets.bzl @@ -15,7 +15,6 @@ def define_common_targets(): deps = [ "//executorch/runtime/executor:program", "//executorch/extension/data_loader:file_data_loader", - "//executorch/extension/data_loader:file_descriptor_data_loader", "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/runner_util:inputs", ], From f9698d84b667b886bf5d133c6cb1a7bc6a9328e0 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Thu, 7 Nov 2024 12:24:55 -0600 Subject: [PATCH 27/59] Move quantize IO passes from BoltNN to ExecuTorch Differential Revision: D65188297 Pull Request resolved: https://github.com/pytorch/executorch/pull/6686 --- exir/passes/TARGETS | 14 ++ exir/passes/quantize_io_pass.py | 259 ++++++++++++++++++++++++++++ exir/tests/TARGETS | 12 ++ exir/tests/test_quantize_io_pass.py | 156 +++++++++++++++++ 4 files changed, 441 insertions(+) create mode 100644 exir/passes/quantize_io_pass.py create mode 100644 exir/tests/test_quantize_io_pass.py diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index eeb1e5265b..a3251589ac 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -16,6 +16,7 @@ python_library( ":normalize_transpose_pass", ":prim_ops_py_registry", ":quant_fusion_pass", + ":quantize_io_pass", ":remove_noop_pass", ":replace_aten_with_edge_pass", ":replace_broken_ops_with_function_ops_pass", @@ -143,6 +144,19 @@ python_library( ], ) +python_library( + name = "quantize_io_pass", + srcs = [ + "quantize_io_pass.py", + ], + deps = [ + "fbsource//third-party/pypi/numpy:numpy", + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + python_library( name = "memory_planning_pass", srcs = [ diff --git a/exir/passes/quantize_io_pass.py b/exir/passes/quantize_io_pass.py new file mode 100644 index 0000000000..21ac4c868a --- /dev/null +++ b/exir/passes/quantize_io_pass.py @@ -0,0 +1,259 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +import logging +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +import torch + +from executorch.exir import EdgeProgramManager +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass +from executorch.exir.tensor import scalar_type_enum +from torch.fx.passes.infra.pass_base import PassResult + +logger = logging.getLogger(__name__) + + +def quantize_input( + exported_program, input_index, qparams: Optional[Dict[str, Any]] = None +): + """ + Modify the program to expect quantized input at given index. The input is expected + to be quantizing this input as the first step. Must be called before + permute_input_layout. Returns the scale, zero point, qmin, qmax, and dtype of the + expected quantization. + """ + graph = exported_program.graph_module.graph + name = exported_program.graph_signature.user_inputs[input_index] + placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name] + assert placeholders + target_placeholder = placeholders[0] + + if len(target_placeholder.users) != 1: + raise ValueError(f"Input {input_index} has more than one users") + quantize = next(iter(target_placeholder.users)) + if ( + quantize.target + != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ): + raise ValueError(f"Input {input_index} is not used by a quantize op") + + # If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op + need_requant = False + if qparams is not None: + assert all( + qparam in qparams for qparam in ["scale", "zp", "dtype"] + ), "dtype/scale/zp must be specified in qparam for input requantization" + if qparams["dtype"] != quantize.args[5]: + if any( + dtype + not in [torch.int8, torch.uint8, torch.bool, torch.int16, torch.uint16] + for dtype in [qparams["dtype"], quantize.args[5]] + ): + raise ValueError( + f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}" + ) + + need_requant = True + elif ( + not np.isclose(qparams["scale"], quantize.args[1]) + or qparams["zp"] != quantize.args[2] + ): + need_requant = True + + if need_requant: + assert qparams is not None + dtype = qparams["dtype"] + qmin = torch.iinfo(dtype).min + qmax = torch.iinfo(dtype).max + scale = qparams["scale"] + zero_point = qparams["zp"] + quant_args = (scale, zero_point, qmin, qmax, dtype) + logger.info( + f"Modifying program to requantize quantized input at index {input_index}" + ) + logger.info(f"Quantization parameters: {quant_args}") + + with exported_program.graph_module.graph.inserting_before(quantize): + input_dequant = exported_program.graph_module.graph.call_function( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + args=( + target_placeholder, + *quant_args, + ), + ) + input_dequant.meta["input_qparams"] = [ + { + "scale": scale, + "zero_point": zero_point, + "qmin": qmin, + "qmax": qmax, + "dtype": dtype, + } + ] + input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32) + target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype) + quantize.replace_input_with(target_placeholder, input_dequant) + else: + quant_args = quantize.args[1:] + logger.info(f"Modifying program to take quantized input at index {input_index}") + logger.info(f"Quantization parameters: {quant_args}") + + target_placeholder.meta["val"] = ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + target_placeholder.meta["val"], *quant_args + ) + ) + quantize.replace_all_uses_with(quantize.args[0]) + + exported_program.graph_module.graph.eliminate_dead_code() + return quant_args + + +def quantize_output(exported_program, output_index): + """ + Modify the program to produce quantized output at given index. The model is expected + to be dequantizing this output as the last step. Must be called before + permute_output_layout. Returns the scale, zero point, qmin, qmax, and dtype of the + output quantization. + """ + graph = exported_program.graph_module.graph + outputs = [n for n in graph.nodes if n.op == "output"] + if len(outputs) != 1: + raise NotImplementedError("Only 1 output node is supported") + + output_node = outputs[0] + output_list = list(output_node.args[0]) + if output_index >= len(output_list): + raise ValueError( + f"{len(output_list)} outputs available, " + + f"output index out of bounds: {output_index}" + ) + + target_output = output_list[output_index] + if ( + target_output.target + != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ): + raise ValueError("Output {output_index} is not a dequantize op") + + dequant = target_output + output_list[output_index] = dequant.args[0] + output_node.args = (output_list,) + dequant_args = dequant.args[1:] + graph.eliminate_dead_code() + + logger.info( + f"Modifying program to produce quantized output at index {output_index}" + ) + logger.info(f"Dequantization parameters: {dequant_args}") + return dequant_args + + +def get_config_method_name( + prefix: Optional[str] = "forward", + arg_type: str = "input", + index: int = 0, + key: str = "scale", +): + if prefix is None: + prefix = "" + else: + prefix = prefix + "_" + assert arg_type in ["input", "output"], "arg_type must be either input or output" + assert index >= 0, "index must be non-negative" + assert key in [ + "scale", + "zp", + "quant_min", + "quant_max", + "dtype", + ], "key must be one of scale, zp, quant_min, quant_max, dtype" + return f"{prefix}{arg_type}{index}_{key}" + + +class QuantizeInputs(ExportPass): + def __init__( + self, + edge_program_manager: EdgeProgramManager, + quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]], + method_name: Optional[str] = None, + ): + super().__init__() + self.edge_program_manager = edge_program_manager + + self.quantized_inputs_idx_dict = {} + if isinstance(quantized_inputs_idx, dict): + self.quantized_inputs_idx_dict = quantized_inputs_idx + else: + for idx in quantized_inputs_idx: + self.quantized_inputs_idx_dict[idx] = None + self.param_prefix_name = method_name + + def call(self, graph_module: torch.fx.GraphModule): + for i, qparams in self.quantized_inputs_idx_dict.items(): + quant_args = quantize_input( + self.edge_program_manager.exported_program(), i, qparams + ) + + if not self.edge_program_manager._config_methods: + self.edge_program_manager._config_methods = {} + + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "scale") + ] = quant_args[0] + self.edge_program_manager._config_methods[ # pyre-ignore + get_config_method_name(self.param_prefix_name, "input", i, "zp") + ] = quant_args[1] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "quant_min") + ] = quant_args[2] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "quant_max") + ] = quant_args[3] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", i, "dtype") + ] = scalar_type_enum(quant_args[4]) + return PassResult(graph_module, True) + + +class QuantizeOutputs(ExportPass): + def __init__( + self, + edge_program_manager: EdgeProgramManager, + quantized_outputs_idx_list: List[int], + method_name: Optional[str] = None, + ): + super().__init__() + self.edge_program_manager = edge_program_manager + self.quantized_outputs_idx_list = quantized_outputs_idx_list + self.param_prefix_name = method_name + + def call(self, graph_module: torch.fx.GraphModule): + for i in self.quantized_outputs_idx_list: + dequant_args = quantize_output( + self.edge_program_manager.exported_program(), i + ) # noqa F841 + + if not self.edge_program_manager._config_methods: + self.edge_program_manager._config_methods = {} + + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "scale") + ] = dequant_args[0] + self.edge_program_manager._config_methods[ # pyre-ignore + get_config_method_name(self.param_prefix_name, "output", i, "zp") + ] = dequant_args[1] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "quant_min") + ] = dequant_args[2] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "quant_max") + ] = dequant_args[3] + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", i, "dtype") + ] = scalar_type_enum(dequant_args[4]) + + return PassResult(graph_module, True) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index f8b4d905fb..1995589f80 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -448,3 +448,15 @@ python_unittest( "//executorch/exir:_warnings", ], ) + +python_unittest( + name = "quantize_io_pass", + srcs = [ + "test_quantize_io_pass.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir/passes:quantize_io_pass", + ], +) diff --git a/exir/tests/test_quantize_io_pass.py b/exir/tests/test_quantize_io_pass.py new file mode 100644 index 0000000000..b3899b008c --- /dev/null +++ b/exir/tests/test_quantize_io_pass.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.exir.passes.quantize_io_pass import ( + get_config_method_name, + QuantizeInputs, + QuantizeOutputs, +) +from executorch.exir.tensor import get_scalar_type +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from torch.testing import FileCheck + +op_str = { + "q": "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default", + "dq": "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default", +} + + +class TestQuantIOPass(unittest.TestCase): + class Add(torch.nn.Module): + def forward(self, x, y): + return x + y + + def _quantize(self, mod, example_inputs): + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config() + quantizer.set_global(operator_config) + m = torch.export.export_for_training( + mod, copy.deepcopy(example_inputs) + ).module() + m = prepare_pt2e(m, quantizer) + _ = m(*example_inputs) + m = convert_pt2e(m) + exported_program = torch.export.export_for_training(m, example_inputs) + return exported_program + + def _check_count(self, op, count, epm): + code = epm.exported_program().graph_module.code + FileCheck().check_count(op, count, exactly=True).run(code) + + def _get_edge_prog_manager(self, mod, example_inputs): + exported_program = self._quantize(mod, example_inputs) + edge_program_manager = to_edge_transform_and_lower( + exported_program, + transform_passes=[], + partitioner=None, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + self._check_count(op_str["dq"], 3, edge_program_manager) + self._check_count(op_str["q"], 3, edge_program_manager) + return edge_program_manager + + def test_add_drop_q_inputs(self) -> None: + example_inputs = (torch.randn(1, 5), torch.randn(1, 5)) + mod = self.Add().eval() + edge_program_manager = self._get_edge_prog_manager(mod, example_inputs) + reference_outputs = edge_program_manager.exported_program().module()( + *example_inputs + ) + + edge_program_manager_qin = edge_program_manager.transform( + [ + QuantizeInputs( + edge_program_manager=edge_program_manager, + quantized_inputs_idx=[0, 1], + method_name="forward", + ) + ] + ) + self._check_count(op_str["dq"], 3, edge_program_manager) + self._check_count(op_str["q"], 1, edge_program_manager) + + quantized_example_inputs = [] + for i in range(len(example_inputs)): + d = edge_program_manager_qin._config_methods + scale = d[get_config_method_name("forward", "input", i, "scale")] + zp = d[get_config_method_name("forward", "input", i, "zp")] + quant_min = d[get_config_method_name("forward", "input", i, "quant_min")] + quant_max = d[get_config_method_name("forward", "input", i, "quant_max")] + dtype = get_scalar_type( + d[get_config_method_name("forward", "input", i, "dtype")] + ) + + quantized_example_inputs.append( + torch.ops.quantized_decomposed.quantize_per_tensor.default( + example_inputs[i], scale, zp, quant_min, quant_max, dtype + ), + ) + quantized_example_inputs = tuple(quantized_example_inputs) + output = edge_program_manager_qin.exported_program().module()( + *quantized_example_inputs + ) + torch.testing.assert_close( + reference_outputs[0], + output[0], + ) + + def test_add_drop_dq_output(self) -> None: + example_inputs = (torch.randn(1, 5), torch.randn(1, 5)) + mod = self.Add().eval() + edge_program_manager = self._get_edge_prog_manager(mod, example_inputs) + reference_outputs = edge_program_manager.exported_program().module()( + *example_inputs + ) + + edge_program_manager_dqout = edge_program_manager.transform( + [ + QuantizeOutputs( + edge_program_manager=edge_program_manager, + quantized_outputs_idx_list=[0], + method_name="forward", + ) + ] + ) + self._check_count(op_str["dq"], 2, edge_program_manager) + self._check_count(op_str["q"], 3, edge_program_manager) + + quantized_outputs = edge_program_manager_dqout.exported_program().module()( + *example_inputs + ) + + dequantized_outputs = [] + for i in range(len(quantized_outputs)): + d = edge_program_manager_dqout._config_methods + scale = d[get_config_method_name("forward", "output", i, "scale")] + zp = d[get_config_method_name("forward", "output", i, "zp")] + q_min = d[get_config_method_name("forward", "output", i, "quant_min")] + q_max = d[get_config_method_name("forward", "output", i, "quant_max")] + dtype = get_scalar_type( + d[get_config_method_name("forward", "output", i, "dtype")] + ) + dequantized_outputs.append( + torch.ops.quantized_decomposed.dequantize_per_tensor.default( + quantized_outputs[i], scale, zp, q_min, q_max, dtype + ) + ) + dequantized_outputs = tuple(dequantized_outputs) + + torch.testing.assert_close( + reference_outputs[0], + dequantized_outputs[0], + ) From 437168ebe5ce04c6203ed62c2488d652a73efbab Mon Sep 17 00:00:00 2001 From: David Lin Date: Thu, 7 Nov 2024 11:04:13 -0800 Subject: [PATCH 28/59] [Android] added tests for Tensor.java Differential Revision: D65608097 Pull Request resolved: https://github.com/pytorch/executorch/pull/6683 --- .../org/pytorch/executorch/TensorTest.java | 270 ++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 extension/android/src/test/java/org/pytorch/executorch/TensorTest.java diff --git a/extension/android/src/test/java/org/pytorch/executorch/TensorTest.java b/extension/android/src/test/java/org/pytorch/executorch/TensorTest.java new file mode 100644 index 0000000000..7933113412 --- /dev/null +++ b/extension/android/src/test/java/org/pytorch/executorch/TensorTest.java @@ -0,0 +1,270 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.pytorch.executorch.Tensor; + +/** Unit tests for {@link Tensor}. */ +@RunWith(JUnit4.class) +public class TensorTest { + + @Test + public void testFloatTensor() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.FLOAT); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5); + assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5); + + FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(4); + floatBuffer.put(data); + tensor = Tensor.fromBlob(floatBuffer, shape); + assertEquals(tensor.dtype(), DType.FLOAT); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5); + assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5); + } + + @Test + public void testIntTensor() { + int data[] = {Integer.MIN_VALUE, 0, 1, Integer.MAX_VALUE}; + long shape[] = {1, 4, 1}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.INT32); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsIntArray()[0]); + assertEquals(data[1], tensor.getDataAsIntArray()[1]); + assertEquals(data[2], tensor.getDataAsIntArray()[2]); + assertEquals(data[3], tensor.getDataAsIntArray()[3]); + + IntBuffer intBuffer = Tensor.allocateIntBuffer(4); + intBuffer.put(data); + tensor = Tensor.fromBlob(intBuffer, shape); + assertEquals(tensor.dtype(), DType.INT32); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsIntArray()[0]); + assertEquals(data[1], tensor.getDataAsIntArray()[1]); + assertEquals(data[2], tensor.getDataAsIntArray()[2]); + assertEquals(data[3], tensor.getDataAsIntArray()[3]); + } + + @Test + public void testDoubleTensor() { + double data[] = {Double.MIN_VALUE, 0.0d, 0.1d, Double.MAX_VALUE}; + long shape[] = {1, 4}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.DOUBLE); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5); + assertEquals(data[3], tensor.getDataAsDoubleArray()[3], 1e-5); + + DoubleBuffer doubleBuffer = Tensor.allocateDoubleBuffer(4); + doubleBuffer.put(data); + tensor = Tensor.fromBlob(doubleBuffer, shape); + assertEquals(tensor.dtype(), DType.DOUBLE); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5); + assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5); + assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5); + assertEquals(data[3], tensor.getDataAsDoubleArray()[3], 1e-5); + } + + @Test + public void testLongTensor() { + long data[] = {Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE}; + long shape[] = {4, 1}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.INT64); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsLongArray()[0]); + assertEquals(data[1], tensor.getDataAsLongArray()[1]); + assertEquals(data[2], tensor.getDataAsLongArray()[2]); + assertEquals(data[3], tensor.getDataAsLongArray()[3]); + + LongBuffer longBuffer = Tensor.allocateLongBuffer(4); + longBuffer.put(data); + tensor = Tensor.fromBlob(longBuffer, shape); + assertEquals(tensor.dtype(), DType.INT64); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsLongArray()[0]); + assertEquals(data[1], tensor.getDataAsLongArray()[1]); + assertEquals(data[2], tensor.getDataAsLongArray()[2]); + assertEquals(data[3], tensor.getDataAsLongArray()[3]); + } + + @Test + public void testSignedByteTensor() { + byte data[] = {Byte.MIN_VALUE, (byte) 0, (byte) 1, Byte.MAX_VALUE}; + long shape[] = {1, 1, 4}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.INT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsByteArray()[0]); + assertEquals(data[1], tensor.getDataAsByteArray()[1]); + assertEquals(data[2], tensor.getDataAsByteArray()[2]); + assertEquals(data[3], tensor.getDataAsByteArray()[3]); + + ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4); + byteBuffer.put(data); + tensor = Tensor.fromBlob(byteBuffer, shape); + assertEquals(tensor.dtype(), DType.INT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsByteArray()[0]); + assertEquals(data[1], tensor.getDataAsByteArray()[1]); + assertEquals(data[2], tensor.getDataAsByteArray()[2]); + assertEquals(data[3], tensor.getDataAsByteArray()[3]); + } + + @Test + public void testUnsignedByteTensor() { + byte data[] = {(byte) 0, (byte) 1, (byte) 2, (byte) 255}; + long shape[] = {4, 1, 1}; + Tensor tensor = Tensor.fromBlobUnsigned(data, shape); + assertEquals(tensor.dtype(), DType.UINT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]); + assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]); + assertEquals(data[2], tensor.getDataAsUnsignedByteArray()[2]); + assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]); + + ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4); + byteBuffer.put(data); + tensor = Tensor.fromBlobUnsigned(byteBuffer, shape); + assertEquals(tensor.dtype(), DType.UINT8); + assertEquals(shape[0], tensor.shape()[0]); + assertEquals(shape[1], tensor.shape()[1]); + assertEquals(shape[2], tensor.shape()[2]); + assertEquals(4, tensor.numel()); + assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]); + assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]); + assertEquals(data[2], tensor.getDataAsUnsignedByteArray()[2]); + assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]); + } + + @Test + public void testIllegalDataTypeException() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + assertEquals(tensor.dtype(), DType.FLOAT); + + try { + tensor.getDataAsByteArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsUnsignedByteArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsIntArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsDoubleArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + try { + tensor.getDataAsLongArray(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + // expected + } + } + + @Test + public void testIllegalArguments() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shapeWithNegativeValues[] = {-1, 2}; + long mismatchShape[] = {1, 2}; + + try { + Tensor tensor = Tensor.fromBlob((float[]) null, mismatchShape); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + try { + Tensor tensor = Tensor.fromBlob(data, null); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + try { + Tensor tensor = Tensor.fromBlob(data, shapeWithNegativeValues); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + try { + Tensor tensor = Tensor.fromBlob(data, mismatchShape); + fail("Should have thrown an exception"); + } catch (IllegalArgumentException e) { + // expected + } + } +} From 4af687a76213d4896adcfd88f70c4d8c20179936 Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:33:36 -0800 Subject: [PATCH 29/59] Revert "Qualcomm AI Engine Direct - Quantizer refine for qat (#6513)" (#6722) --- .../qualcomm/quantizer/custom_annotation.py | 6 +- .../observers/per_channel_param_observer.py | 104 ---- backends/qualcomm/quantizer/qconfig.py | 464 ------------------ backends/qualcomm/quantizer/quantizer.py | 140 +++--- .../quantizer/{annotators.py => utils.py} | 447 +++++++++++++++-- backends/qualcomm/tests/test_qnn_delegate.py | 25 +- backends/qualcomm/tests/utils.py | 38 +- backends/qualcomm/utils/utils.py | 2 +- examples/qualcomm/oss_scripts/fastvit.py | 18 +- examples/qualcomm/oss_scripts/llama2/llama.py | 8 +- examples/qualcomm/scripts/export_example.py | 7 +- examples/qualcomm/utils.py | 96 ++-- extension/llm/export/quantizer_lib.py | 19 +- 13 files changed, 584 insertions(+), 790 deletions(-) delete mode 100644 backends/qualcomm/quantizer/observers/per_channel_param_observer.py delete mode 100644 backends/qualcomm/quantizer/qconfig.py rename backends/qualcomm/quantizer/{annotators.py => utils.py} (68%) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 9d6dea8a97..db82172a9e 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -6,12 +6,12 @@ from typing import Sequence import torch -from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_8a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, QuantizationConfig, ) +from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.quantizer import ( QuantizationAnnotation, @@ -110,7 +110,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): # Annotate 16a8w for matmul op to get better performance quantization_config_16a8w = get_16a8w_qnn_ptq_config() # Annotate 8a8w for second input of matmul until past_kv_cache - quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: if "nn_module_stack" in node.meta: diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py deleted file mode 100644 index d556dfa4ba..0000000000 --- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -from torch.ao.quantization.observer import UniformQuantizationObserverBase - - -# TODO move to torch/ao/quantization/observer.py. -class PerChannelParamObserver(UniformQuantizationObserverBase): - def __init__( - self, - ch_axis=0, - use_mse=True, - steps=100, - dtype=torch.int8, - qscheme=torch.per_channel_symmetric, - reduce_range=False, - quant_min=None, - quant_max=None, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps, # noqa: B008 - is_dynamic=False, - **kwargs, - ) -> None: - super().__init__( - dtype=dtype, - qscheme=qscheme, - reduce_range=reduce_range, - quant_min=quant_min, - quant_max=quant_max, - factory_kwargs=factory_kwargs, - eps=eps, - is_dynamic=is_dynamic, - **kwargs, - ) - - factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) - self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) - self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) - self.ch_axis = ch_axis - self.use_mse = use_mse - self.steps = steps - self.calibrated = False - - def to_ch_axis(self, x): - axis_order = list(range(len(x.size()))) - axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis - return torch.flatten(x.permute(axis_order), start_dim=1) - - def mse(self, pred, expect): - loss = (pred - expect).abs().pow(2) - return self.to_ch_axis(loss).mean(1) - - def cosine(self, pred, expect): - target = torch.ones(pred.shape[self.ch_axis]) - pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) - expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) - return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) - - def loss_fn(self, x, new_min, new_max): - scale, offset = self._calculate_qparams(new_min, new_max) - x_q = torch.fake_quantize_per_channel_affine( - x, - scale.data, - offset.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) - - def line_search(self, x): - x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) - x_range = torch.max(x_min.abs(), x_max) - optimal_loss = torch.zeros_like(x_min) + 1e9 - - # check which clip range could produce smallest loss - for i in range(1, self.steps + 1): - thres = x_range / self.steps * i - current_loss = self.loss_fn(x, -thres, thres) - x_min = torch.where(current_loss < optimal_loss, -thres, x_min) - x_max = torch.where(current_loss < optimal_loss, thres, x_max) - optimal_loss = torch.min(current_loss, optimal_loss) - - return x_min, x_max - - def forward(self, x_orig): - # since params are static, one calibration is enough - if not self.calibrated: - x = x_orig.detach().to(self.min_val.dtype) - self.min_val, self.max_val = self.line_search(x) - self.calibrated = True - - # return fake-quant result for saturating outliers - scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) - return torch.fake_quantize_per_channel_affine( - x_orig, - scale.data, - zero_point.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - - @torch.jit.export - def calculate_qparams(self): - return self._calculate_qparams(self.min_val, self.max_val) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py deleted file mode 100644 index e07ca24d90..0000000000 --- a/backends/qualcomm/quantizer/qconfig.py +++ /dev/null @@ -1,464 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -from torch import Tensor -from torch.ao.quantization.fake_quantize import ( - FakeQuantize, - FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( - MinMaxObserver, - MovingAverageMinMaxObserver, - MovingAveragePerChannelMinMaxObserver, - PerChannelMinMaxObserver, -) -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec -from torch.fx import Node - - -@dataclass(eq=True, frozen=True) -class QuantizationConfig: - input_activation: Optional[QuantizationSpec] - output_activation: Optional[QuantizationSpec] - weight: Optional[QuantizationSpec] - bias: Optional[QuantizationSpec | Callable] - - -def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: - def _derive_bias_qparams_fn( - obs_or_fqs: List, - ) -> Tuple[Tensor, Tensor]: - assert ( - len(obs_or_fqs) == 2 - ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( - act_scale, weight_scale - ) - derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - return (derived_scale, derived_zero) - - input_act = node.args[0] - assert isinstance(input_act, Node) - weight = node.args[1] - assert isinstance(weight, Node) - - return DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - ch_axis=0, - qscheme=torch.per_channel_symmetric, - ) - - -def get_8a8w_qnn_ptq_config( - act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a8w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a16w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min + 1, - quant_max=torch.iinfo(torch.int16).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - # torch does not support uint16 quantization, use int32 to bypass - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int8, - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# TODO merge qat and ptq to a fucntion, and use a bool flag to control it -def get_8a8w_qnn_qat_config( - act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver -) -> QuantizationConfig: - act_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - reduce_range=True, - observer=act_observer, - ) - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_fake_quant_ctr, - ) - - weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=weight_fake_quant_ctr, - ) - - bias_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=bias_fake_quant_ctr, - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a4w_qnn_qat_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - act_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - reduce_range=True, - observer=act_observer, - ) - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_fake_quant_ctr, - ) - - weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=weight_fake_quant_ctr, - ) - - bias_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=bias_fake_quant_ctr, - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_qat_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int8, - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - reduce_range=True, - observer=act_observer, - ) - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_fake_quant_ctr, - ) - - weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer=MovingAveragePerChannelMinMaxObserver, - ) - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=weight_fake_quant_ctr, - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 50ed07788f..9e5aaf782a 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from enum import IntEnum, unique -from typing import Callable, Optional, Sequence, Set +from typing import Callable, Dict, Optional, Sequence, Set import torch from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum @@ -22,17 +22,14 @@ from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule -from .annotators import OP_ANNOTATOR - -from .qconfig import ( - get_16a16w_qnn_ptq_config, +from .utils import ( get_16a4w_qnn_ptq_config, - get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, - get_8a8w_qnn_ptq_config, - get_8a8w_qnn_qat_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qat_proto, + get_default_8bit_qnn_ptq_config, get_ptq_per_channel_quant_config, - get_qat_per_channel_quant_config, + OP_ANNOTATOR, QuantizationConfig, ) @@ -41,10 +38,9 @@ "QuantDtype", "get_16a4w_qnn_ptq_config", "get_16a8w_qnn_ptq_config", - "get_16a16w_qnn_ptq_config", - "get_8a8w_qnn_ptq_config", - "get_8a8w_qnn_qat_config", - "get_16a4w_qnn_qat_config", + "get_default_16bit_qnn_ptq_config", + "get_default_8bit_qnn_ptq_config", + "get_default_8bit_qat_proto", ] @@ -55,39 +51,8 @@ class QuantDtype(IntEnum): """ use_16a16w = 0 - use_16a8w = 1 - use_16a4w = 2 - use_8a8w = 3 - - -quant_config_dict = { - # PTQ - (QuantDtype.use_16a16w, False): ( - get_16a16w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, torch.int16), - ), - (QuantDtype.use_16a8w, False): ( - get_16a8w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, torch.int8), - ), - (QuantDtype.use_16a4w, False): ( - get_16a4w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, "int4"), - ), - (QuantDtype.use_8a8w, False): ( - get_8a8w_qnn_ptq_config, - get_ptq_per_channel_quant_config(), - ), - # QAT, - (QuantDtype.use_16a4w, True): ( - get_16a4w_qnn_qat_config, - get_qat_per_channel_quant_config(torch.uint16, "int4"), - ), - (QuantDtype.use_8a8w, True): ( - get_8a8w_qnn_qat_config, - get_qat_per_channel_quant_config(), - ), -} + use_16a4w = 1 + use_8a8w = 2 class QnnQuantizer(Quantizer): @@ -95,17 +60,23 @@ class QnnQuantizer(Quantizer): def __init__(self): super().__init__() - self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() + self.bit8_quant_config: QuantizationConfig = get_default_8bit_qnn_ptq_config() + self.bit16_quant_config: QuantizationConfig = get_default_16bit_qnn_ptq_config() - self.is_qat = False - self.quant_dtype = QuantDtype.use_8a8w - self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config() - self.per_channel_quant_config = get_ptq_per_channel_quant_config() - self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() + self.bit8_quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() + self.bit16_quant_ops: Set[OpOverload] = set() self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() + self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() + # the weight quantized for activation 8 bits and 16 bits + self.per_channel_weight_dtype: Dict = { + "8bit_act": torch.int8, + "16bit_act": torch.int16, + } + self.per_channel_quant_config = None + def _annotate(self, gm: GraphModule) -> None: for node in gm.graph.nodes: if node.name in self.discard_nodes: @@ -123,16 +94,29 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig """ Priority: 1. is one of use_per_channel_weight_quant_ops - 2. quant config + 2. int8 / int16 config """ if isinstance(op, str): return if op in self.use_per_channel_weight_quant_ops: + if self.per_channel_quant_config is None: + if op in self.bit16_quant_ops: + return get_ptq_per_channel_quant_config( + act_dtype=torch.uint16, + weight_dtype=self.per_channel_weight_dtype["16bit_act"], + ) + return get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, + weight_dtype=self.per_channel_weight_dtype["8bit_act"], + ) return self.per_channel_quant_config - if op in self.quant_ops: - return self.quant_config + if op in self.bit8_quant_ops: + return self.bit8_quant_config + + if op in self.bit16_quant_ops: + return self.bit16_quant_config print(f"No quant config is implemented for op, {op}") @@ -142,6 +126,15 @@ def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: boo else: self.use_per_channel_weight_quant_ops.difference_update(ops) + def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None: + for op in ops: + assert ( + op in self.SUPPORTED_OPS + ), f"The annotation of op {op} is not implemented" + + self.bit8_quant_ops.remove(op) + self.bit16_quant_ops.add(op) + def add_custom_quant_annotations( self, custom_quant_annotations: Sequence[Callable] ) -> None: @@ -152,7 +145,10 @@ def add_discard_nodes(self, nodes: Sequence[str]) -> None: def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: for op in ops: - self.quant_ops.remove(op) + if op in self.bit8_quant_ops: + self.bit8_quant_ops.remove(op) + if op in self.bit16_quant_ops: + self.bit16_quant_ops.remove(op) def annotate(self, model: GraphModule) -> GraphModule: self._annotate(model) @@ -163,22 +159,24 @@ def annotate(self, model: GraphModule) -> GraphModule: def get_supported_ops(self) -> Set[OpOverload]: return self.SUPPORTED_OPS - def set_quant_config( - self, quant_dtype: QuantDtype, is_qat=False, act_observer=None + def set_bit16_op_quant_config( + self, quantization_config: QuantizationConfig + ) -> None: + self.bit16_quant_config = quantization_config + + def set_bit8_op_quant_config(self, quantization_config: QuantizationConfig) -> None: + self.bit8_quant_config = quantization_config + + def set_per_channel_weight_dtype( + self, + weight_dtype_for_8bit_act: Optional[str | torch.dtype] = None, + weight_dtype_for_16bit_act: Optional[str | torch.dtype] = None, ) -> None: - self.quant_dtype = quant_dtype - self.is_qat = is_qat - if (quant_dtype, is_qat) not in quant_config_dict: - raise RuntimeError( - f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support" - ) - - quant_config_fuc, self.per_channel_quant_config = quant_config_dict[ - (quant_dtype, is_qat) - ] - self.quant_config = ( - quant_config_fuc(act_observer) if act_observer else quant_config_fuc() - ) + # TODO accept temporally str type. Remove it when torch support torch.int4 dtype + if weight_dtype_for_8bit_act: + self.per_channel_weight_dtype["8bit_act"] = weight_dtype_for_8bit_act + if weight_dtype_for_16bit_act: + self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act def set_per_channel_conv_quant(self, enable: bool) -> None: conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default} diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/utils.py similarity index 68% rename from backends/qualcomm/quantizer/annotators.py rename to backends/qualcomm/quantizer/utils.py index 275da567e8..dc3d2a6841 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/utils.py @@ -5,16 +5,29 @@ # LICENSE file in the root directory of this source tree. import numbers import operator +from dataclasses import dataclass from functools import partial -from typing import Callable, Dict, List, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch -from torch._ops import OpOverload +from torch import Tensor +from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.observer import FixedQParamsObserver +from torch.ao.quantization.fake_quantize import ( + default_fake_quant, + FusedMovingAvgObsFakeQuantize, +) + +from torch.ao.quantization.observer import ( + FixedQParamsObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + PerChannelMinMaxObserver, + UniformQuantizationObserverBase, +) + from torch.ao.quantization.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, @@ -27,12 +40,397 @@ ) from torch.fx import Node -from .qconfig import ( - get_16a16w_qnn_ptq_config, - get_16a4w_qnn_qat_config, - get_8a8w_qnn_qat_config, - QuantizationConfig, -) + +class ParamObserver(UniformQuantizationObserverBase): + def __init__( + self, + ch_axis=0, + use_mse=True, + steps=100, + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.ch_axis = ch_axis + self.use_mse = use_mse + self.steps = steps + self.calibrated = False + + def to_ch_axis(self, x): + axis_order = list(range(len(x.size()))) + axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis + return torch.flatten(x.permute(axis_order), start_dim=1) + + def mse(self, pred, expect): + loss = (pred - expect).abs().pow(2) + return self.to_ch_axis(loss).mean(1) + + def cosine(self, pred, expect): + target = torch.ones(pred.shape[self.ch_axis]) + pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) + expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) + return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) + + def loss_fn(self, x, new_min, new_max): + scale, offset = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_channel_affine( + x, + scale.data, + offset.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) + + def line_search(self, x): + x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) + x_range = torch.max(x_min.abs(), x_max) + optimal_loss = torch.zeros_like(x_min) + 1e9 + + # check which clip range could produce smallest loss + for i in range(1, self.steps + 1): + thres = x_range / self.steps * i + current_loss = self.loss_fn(x, -thres, thres) + x_min = torch.where(current_loss < optimal_loss, -thres, x_min) + x_max = torch.where(current_loss < optimal_loss, thres, x_max) + optimal_loss = torch.min(current_loss, optimal_loss) + + return x_min, x_max + + def forward(self, x_orig): + # since params are static, one calibration is enough + if not self.calibrated: + x = x_orig.detach().to(self.min_val.dtype) + self.min_val, self.max_val = self.line_search(x) + self.calibrated = True + + # return fake-quant result for saturating outliers + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + return torch.fake_quantize_per_channel_affine( + x_orig, + scale.data, + zero_point.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + + @torch.jit.export + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec | Callable] + + +def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: + def _derive_bias_qparams_fn( + obs_or_fqs: List, + ) -> Tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( + act_scale, weight_scale + ) + derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) + derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + return (derived_scale, derived_zero) + + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + return DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + ch_axis=0, + qscheme=torch.per_channel_symmetric, + ) + + +def get_default_8bit_qat_proto(act_symmetric: bool = False) -> QuantizationConfig: + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=default_fake_quant, + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver + ), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=default_fake_quant, + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_8bit_qnn_ptq_config( + act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# 4 bits quantization only supports specific ops. +def get_16a4w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a8w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_16bit_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # torch does not support uint16 quantization, use int32 to bypass + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, weight_dtype=torch.int8 +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config QUANT_ANNOTATION_KEY = "quantization_annotation" @@ -503,34 +901,19 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non scale = 1 / (q_max - q_min + 1) - bias_obs_ctr = observer = FixedQParamsObserver.with_args( - scale=scale, - zero_point=0, + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( dtype=quantization_config.output_activation.dtype, - qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - ) - if quantization_config in ( - get_8a8w_qnn_qat_config(), - get_16a4w_qnn_qat_config(), - ): - bias_obs_ctr = FixedQParamsFakeQuantize.with_args( - observer=observer, + observer_or_fake_quant_ctr=FixedQParamsObserver.with_args( scale=scale, zero_point=0, dtype=quantization_config.output_activation.dtype, qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - ) - - # make sigmoid map to the range between 0~1 - out_act_quantization_spec = QuantizationSpec( - dtype=quantization_config.output_activation.dtype, - quant_max=q_max, - quant_min=q_min, - observer_or_fake_quant_ctr=bias_obs_ctr, + ), qscheme=torch.torch.per_tensor_affine, ) @@ -703,7 +1086,7 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -732,7 +1115,7 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -875,7 +1258,7 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> _annotate_input_qspec_map( node, weight_node, - get_16a16w_qnn_ptq_config().weight, + get_default_16bit_qnn_ptq_config().weight, ) else: _annotate_input_qspec_map( diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 64b0490d46..4bfdedcd4b 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -698,17 +698,6 @@ def test_qnn_backend_16a4w_conv2d(self): ) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_16a4w_conv2d_qat(self): - modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405 - sample_input = (torch.randn([1, 1, 3, 3]),) - for i, module in enumerate(modules): - with self.subTest(i=i): - prepared = self.get_prepared_qat_module(module, sample_input) - converted = self.get_converted_sgd_trained_module( - module, prepared, sample_input - ) - self.lower_module_and_test_output(converted, sample_input) - def test_qnn_backend_16a4w_layer_norm(self): module = LayerNorm() # noqa: F405 sample_input = (torch.randn(196, 768),) @@ -1074,8 +1063,18 @@ def test_qnn_backend_linear_qat(self): """ module = Linear() # noqa: F405 sample_input = (torch.randn([3, 4]),) - prepared = self.get_prepared_qat_module(module, sample_input) - module = self.get_converted_sgd_trained_module(module, prepared, sample_input) + + module = self.get_prepared_qat_module(module, sample_input) + + optimizer = torch.optim.SGD(module.parameters(), lr=0.1) + criterion = torch.nn.CrossEntropyLoss() + output = module(*sample_input) + loss = criterion(output, module(*sample_input)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + module = torch.ao.quantization.quantize_pt2e.convert_pt2e(module) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_log_softmax(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index d2a3e7c241..114493c7d2 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -17,7 +17,13 @@ from executorch import exir from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.qnn_preprocess import QnnBackend -from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qat_proto, + QnnQuantizer, + QuantDtype, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -399,7 +405,18 @@ def get_qdq_module( quantizer.add_custom_quant_annotations(custom_quant_annotations) quantizer.set_per_channel_conv_quant(is_conv_per_channel) quantizer.set_per_channel_linear_quant(is_linear_per_channel) - quantizer.set_quant_config(quant_dtype) + + if quant_dtype == QuantDtype.use_8a8w: + pass # default setting + elif quant_dtype == QuantDtype.use_16a16w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") prepared = prepare_pt2e(m, quantizer) prepared(*inputs) @@ -431,28 +448,13 @@ def get_prepared_qat_module( quantizer.set_per_channel_linear_quant(is_linear_per_channel) if quant_dtype == QuantDtype.use_8a8w: - quantizer.set_quant_config(quant_dtype, is_qat=True) + quantizer.set_bit8_op_quant_config(get_default_8bit_qat_proto()) else: raise RuntimeError("Shuld not be here") prepared = prepare_qat_pt2e(m, quantizer) return torch.ao.quantization.move_exported_model_to_train(prepared) - def get_converted_sgd_trained_module( - self, - ori_module: torch.nn.Module, - prepared: torch.nn.Module, - inputs: Tuple[torch.Tensor], - ) -> torch.fx.GraphModule: - optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) - criterion = torch.nn.CrossEntropyLoss() - output = prepared(*inputs) - loss = criterion(output, ori_module(*inputs)) - optimizer.zero_grad() - loss.backward() - optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) - def split_graph(self, graph_module: torch.fx.GraphModule, division: int): class SplitGraph(ExportPass): """ diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index cb54412add..0ea4512abc 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -331,7 +331,7 @@ def _transform( def capture_program( module: torch.nn.Module, inputs: Tuple[torch.Tensor], - custom_pass_config: Set[str] = frozenset(), + custom_pass_config: Set[str] = None, ) -> exir.ExirExportedProgram: ep = torch.export.export(module, inputs) decomposed_ep = ep.run_decompositions(get_decomp_table()) diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index 0e2c695ab3..30fe74f35b 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -10,19 +10,15 @@ import numpy as np import torch -from executorch.backends.qualcomm.quantizer.annotators import ( - QuantizationConfig, - QuantizationSpec, -) -from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( - PerChannelParamObserver, -) -from executorch.backends.qualcomm.quantizer.qconfig import ( + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.quantizer.utils import ( _derived_bias_quant_spec, MovingAverageMinMaxObserver, + ParamObserver, + QuantizationConfig, + QuantizationSpec, ) - -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_EXPAND_BROADCAST_SHAPE, ) @@ -91,7 +87,7 @@ def main(args): quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_channel_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=PerChannelParamObserver.with_args( + observer_or_fake_quant_ctr=ParamObserver.with_args( **{"steps": 200, "use_mse": True} ), ) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index 9f7198a344..04569df5c9 100644 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -56,12 +56,12 @@ def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: This function is specific for matmul op 16a8w. """ - from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_8a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, QuantizationConfig, ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import ( QuantizationAnnotation, SharedQuantizationSpec, @@ -119,7 +119,7 @@ def annotate_single_in_single_out( ) def annotate_matmul_input1(node: Node): - quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) while isinstance(node, Node) and node.op == "call_function": if node.target in [ torch.ops.aten.permute.default, @@ -142,11 +142,11 @@ def annotate_matmul_input1(node: Node): def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: - from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_ptq_per_channel_quant_config, QuantizationConfig, ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import Node diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index 56169e39a2..2e49a2344b 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -4,7 +4,10 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer +from executorch.backends.qualcomm.quantizer.quantizer import ( + get_default_8bit_qnn_ptq_config, + QnnQuantizer, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -61,6 +64,8 @@ def main() -> None: # Get quantizer quantizer = QnnQuantizer() + quant_config = get_default_8bit_qnn_ptq_config() + quantizer.set_bit8_op_quant_config(quant_config) # Typical pytorch 2.0 quantization flow m = torch.export.export(model.eval(), example_inputs).module() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 100008e91c..06225be2d1 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -16,7 +16,13 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + QnnQuantizer, + QuantDtype, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -31,11 +37,7 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torch.ao.quantization.observer import MovingAverageMinMaxObserver -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e class SimpleADB: @@ -185,58 +187,36 @@ def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None): callback() -def ptq_calibrate(captured_model, quantizer, dataset): - annotated_model = prepare_pt2e(captured_model, quantizer) - print("Quantizing(PTQ) the model...") - # calibration - if callable(dataset): - dataset(annotated_model) - else: - for data in dataset: - annotated_model(*data) - return annotated_model - - -def qat_train(ori_model, captured_model, quantizer, dataset): - data, targets = dataset - annotated_model = torch.ao.quantization.move_exported_model_to_train( - prepare_qat_pt2e(captured_model, quantizer) - ) - optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) - criterion = torch.nn.CrossEntropyLoss() - for i, d in enumerate(data): - print(f"Epoch {i}") - if i > 3: - # Freeze quantizer parameters - annotated_model.apply(torch.ao.quantization.disable_observer) - if i > 2: - # Freeze batch norm mean and variance estimates - annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) - - output = annotated_model(*d) - loss = criterion(output, targets[i]) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - return torch.ao.quantization.quantize_pt2e.convert_pt2e( - torch.ao.quantization.move_exported_model_to_eval(annotated_model) - ) - - def make_quantizer( - quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w, + quant_dtype: Optional[QuantDtype], custom_annotations=(), per_channel_conv=True, per_channel_linear=False, act_observer=MovingAverageMinMaxObserver, - is_qat=False, ): quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) quantizer.set_per_channel_conv_quant(per_channel_conv) quantizer.set_per_channel_linear_quant(per_channel_linear) - quantizer.set_quant_config(quant_dtype, is_qat, act_observer) + + if quant_dtype == QuantDtype.use_8a8w: + quantizer.set_bit8_op_quant_config( + get_default_8bit_qnn_ptq_config(act_observer=act_observer) + ) + elif quant_dtype == QuantDtype.use_16a16w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_default_16bit_qnn_ptq_config(act_observer=act_observer) + ) + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=act_observer) + ) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + return quantizer @@ -255,22 +235,18 @@ def build_executorch_binary( metadata=None, dump_intermediate_outputs=False, custom_pass_config=frozenset(), - qat_training_data=None, ): if quant_dtype is not None: + quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) captured_model = torch.export.export(model, inputs).module() - if qat_training_data: - quantizer = custom_quantizer or make_quantizer( - quant_dtype=quant_dtype, is_qat=True - ) - # qat training - annotated_model = qat_train( - model, captured_model, quantizer, qat_training_data - ) + annotated_model = prepare_pt2e(captured_model, quantizer) + print("Quantizing the model...") + # calibration + if callable(dataset): + dataset(annotated_model) else: - quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) - # ptq calibration - annotated_model = ptq_calibrate(captured_model, quantizer, dataset) + for data in dataset: + annotated_model(*data) quantized_model = convert_pt2e(annotated_model) edge_prog = capture_program(quantized_model, inputs, custom_pass_config) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index ba281864a9..fd368d73f1 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -144,7 +144,6 @@ def check_embedding_byte_registered(): def get_qnn_quantizer( pt2e_quantize: str, quantization_mode: Optional[str] = None, - is_qat: bool = False, ): try: from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] @@ -153,6 +152,8 @@ def get_qnn_quantizer( # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, QnnQuantizer, QuantDtype, ) @@ -174,7 +175,6 @@ def get_qnn_quantizer( custom_annotations = () if quant_config == "8a8w": quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16] - qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat) elif quant_config == "16a16w": quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16] # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w @@ -184,17 +184,20 @@ def get_qnn_quantizer( ) qnn_quantizer.set_per_channel_conv_quant(enable=False) qnn_quantizer.set_per_channel_linear_quant(enable=False) - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - qnn_quantizer.set_quant_config( - quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver + qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) + qnn_quantizer.set_bit16_op_quant_config( + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + get_default_16bit_qnn_ptq_config(act_observer=MinMaxObserver) ) elif quant_config == "16a4w": # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. quant_dtype = QuantDtype.use_16a4w - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - qnn_quantizer.set_quant_config( - quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver + qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) + qnn_quantizer.set_bit16_op_quant_config( + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) ) + qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. custom_annotations = (custom_annotate_llama_matmul_16a8w,) else: From cb2a0e71769c7b3428077204a3c8eddfa317f00e Mon Sep 17 00:00:00 2001 From: winskuo-quic <143469905+winskuo-quic@users.noreply.github.com> Date: Fri, 8 Nov 2024 06:39:23 +0800 Subject: [PATCH 30/59] Qualcomm AI Engine Direct - Reduce redundant observers (#6351) --- backends/qualcomm/quantizer/utils.py | 32 +++++++++++++++++++++------- examples/qualcomm/utils.py | 6 ++++-- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index dc3d2a6841..223b068375 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -229,14 +229,29 @@ def get_default_8bit_qnn_ptq_config( ) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-12} - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + else: + # PyTorch will remove redundant observers based on attributes such as: + # dtype, quant_min, quant_max, ch_axis, etc. + # Providing values like quant_min and quant_max can help observers compare + # and further reduce the number of observers. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8, @@ -409,6 +424,7 @@ def get_ptq_per_channel_quant_config( quant_min=torch.iinfo(act_dtype).min, quant_max=torch.iinfo(act_dtype).max, qscheme=torch.per_tensor_affine, + ch_axis=0, observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), ) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 06225be2d1..ae5444023a 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -348,7 +348,9 @@ def histogram(golden, predict): return (pa, mpa, miou, cls_iou) -def get_imagenet_dataset(dataset_path, data_size, image_shape, crop_size=None): +def get_imagenet_dataset( + dataset_path, data_size, image_shape, crop_size=None, shuffle=True +): from torchvision import datasets, transforms def get_data_loader(): @@ -365,7 +367,7 @@ def get_data_loader(): imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) return torch.utils.data.DataLoader( imagenet_data, - shuffle=True, + shuffle=shuffle, ) # prepare input data From 39e5b91c625ee75b6b77e424fe7a929cb4596915 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 7 Nov 2024 14:58:14 -0800 Subject: [PATCH 31/59] [ET-VK][ez] properly parse skip memory metadata pass (#6723) Pull Request resolved: https://github.com/pytorch/executorch/pull/6712 As title. Currently, the compile option to skip memory metadata tagging is not being passed correctly to `vulkan_preprocess`. ghstack-source-id: 252359943 @exported-using-ghexport Differential Revision: [D65600049](https://our.internmc.facebook.com/intern/diff/D65600049/) Co-authored-by: Stephen Jia --- backends/vulkan/partitioner/vulkan_partitioner.py | 4 ++++ backends/vulkan/vulkan_preprocess.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index f1fd47fb2b..7b2ad3fdfd 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -252,6 +252,10 @@ def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: value_bytes = int(value).to_bytes(4, byteorder="little") compile_specs.append(CompileSpec(key, value_bytes)) + if isinstance(value, bool): + value_bytes = value.to_bytes(1, byteorder="little") + compile_specs.append(CompileSpec(key, value_bytes)) + if key == "texture_limits": compile_specs.append( CompileSpec( diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index f0a5fd6725..c938f9ff42 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -98,6 +98,10 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: ) if spec.key in {"texture_limits_x", "texture_limits_y", "texture_limits_z"}: options[spec.key] = int.from_bytes(spec.value, byteorder="little") + + if spec.key == "skip_tag_memory_metadata": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored return options From 485a5dfcf428770098b962e8bcbaacd10290cab3 Mon Sep 17 00:00:00 2001 From: Fredrik Knutsson Date: Fri, 8 Nov 2024 16:57:57 +0100 Subject: [PATCH 32/59] Revert "Run tosa_reference_model using python binding" (#6729) Revert "Run tosa_reference_model using python binding (#6658)" This reverts commit 4bbe9945b7c221f7b687dbb6754ce4e650c93c05. --- backends/arm/arm_backend.py | 9 ++- backends/arm/test/common.py | 12 ++-- backends/arm/test/misc/test_debug_feats.py | 5 +- backends/arm/test/ops/test_cat.py | 2 +- backends/arm/test/ops/test_select.py | 4 +- backends/arm/test/runner_utils.py | 81 ++++------------------ backends/arm/test/tester/arm_tester.py | 15 ++-- examples/arm/setup.sh | 27 ++++++-- 8 files changed, 59 insertions(+), 96 deletions(-) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index db3b368115..28af583106 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -13,7 +13,7 @@ import logging import os -from typing import cast, final, List, Optional +from typing import final, List, Optional import serializer.tosa_serializer as ts from executorch.backends.arm.arm_vela import vela_compile @@ -31,7 +31,6 @@ from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram -from torch.fx import Node # TOSA backend debug functionality logger = logging.getLogger(__name__) @@ -226,7 +225,6 @@ def preprocess( # noqa: C901 node_visitors = get_node_visitors(edge_program) for node in graph_module.graph.nodes: - node = cast(Node, node) if node.op == "call_function": process_call_function(node, tosa_graph, node_visitors) elif node.op == "placeholder": @@ -238,6 +236,9 @@ def preprocess( # noqa: C901 # any checking of compatibility. dbg_fail(node, tosa_graph, artifact_path) + # TODO: It would be awesome if this dump could somehow be done on top level and not here. + # Problem is that the desc.json has to be created on the tosa_graph object, which we can't + # access from top level. if artifact_path: tag = _get_first_delegation_tag(graph_module) dbg_tosa_dump( @@ -258,4 +259,6 @@ def preprocess( # noqa: C901 else: raise RuntimeError(f"Unknown format {output_format}") + # Continueing from above. Can I put tosa_graph into this function? + # debug_handle_map = ... return PreprocessResult(processed_bytes=binary) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 1a155c0323..b0e2a7f0bb 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -192,15 +192,19 @@ def get_tosa_compile_spec_unbuilt( the compile spec before calling .build() to finalize it. """ if not custom_path: - custom_path = maybe_get_tosa_collate_path() + intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp( + prefix="arm_tosa_" + ) + else: + intermediate_path = custom_path - if custom_path is not None and not os.path.exists(custom_path): - os.makedirs(custom_path, exist_ok=True) + if not os.path.exists(intermediate_path): + os.makedirs(intermediate_path, exist_ok=True) compile_spec_builder = ( ArmCompileSpecBuilder() .tosa_compile_spec() .set_permute_memory_format(permute_memory_to_nhwc) - .dump_intermediate_artifacts_to(custom_path) + .dump_intermediate_artifacts_to(intermediate_path) ) return compile_spec_builder diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 1aa3e82c76..7d9a18a80e 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -107,10 +107,7 @@ def test_numerical_diff_prints(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=True, - custom_path=tempfile.mkdtemp("diff_print_test"), - ), + compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), ) .export() .to_edge() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index b0a38ce198..9723ba0f0c 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int): def test_cat_4d_tosa_MI(self): square = torch.ones((2, 2, 2, 2)) for dim in range(-3, 3): - test_data = ((square, square.clone()), dim) + test_data = ((square, square), dim) self._test_cat_tosa_MI_pipeline(self.Cat(), test_data) @parameterized.expand(Cat.test_parameters) diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index 6a47c2e66b..fdb2fa1463 100644 --- a/backends/arm/test/ops/test_select.py +++ b/backends/arm/test/ops/test_select.py @@ -93,6 +93,8 @@ def _test_select_tosa_BI_pipeline( .check(["torch.ops.quantized_decomposed"]) .to_edge() .partition() + .dump_artifact() + .dump_operator_distribution() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data) @@ -160,14 +162,12 @@ def test_select_int_tosa_MI(self, test_data: test_data_t): ) @parameterized.expand(test_data_suite) - @unittest.skip def test_select_copy_tosa_BI(self, test_data: test_data_t): self._test_select_tosa_BI_pipeline( self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int" ) @parameterized.expand(test_data_suite) - @unittest.skip def test_select_int_tosa_BI(self, test_data: test_data_t): self._test_select_tosa_BI_pipeline( self.SelectInt(), test_data, export_target="torch.ops.aten.select.int" diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index f3c90eda83..d2ee113a5d 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -17,14 +17,11 @@ import numpy as np import torch -import tosa_reference_model - from torch.export import ExportedProgram from torch.fx.node import Node -from tosa import TosaGraph logger = logging.getLogger(__name__) -logger.setLevel(logging.CRITICAL) +logger.setLevel(logging.WARNING) class QuantizationParams: @@ -170,7 +167,7 @@ def __init__( ): self.intermediate_path = intermediate_path self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model" - assert self.intermediate_path is None or os.path.exists( + assert os.path.exists( self.intermediate_path ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}" @@ -326,46 +323,7 @@ def run_corstone( tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32) output_shape = self.output_node.args[0][0].meta["val"].shape tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape) - return tosa_ref_output - - def run_tosa_graph( - self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor] - ) -> torch.Tensor: - """Runs the TOSA reference model with inputs and returns the result.""" - data_np = [ - prep_data_for_save( - input, self.is_quantized, self.input_names[i], self.qp_input[i] - ) - for i, input in enumerate(inputs) - ] - # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training. - tosa_profile = 0 if self.is_quantized else 1 - debug_mode = "ALL" if logger.level <= logging.DEBUG else None - outputs, status = tosa_reference_model.run( - graph, - data_np, - verbosity=_tosa_refmodel_loglevel(logger.level), - tosa_profile=tosa_profile, - initialize_variable_tensor_from_numpy=1, # True - debug_mode=debug_mode, - ) - - assert ( - status == tosa_reference_model.GraphStatus.TOSA_VALID - ), "Non-valid TOSA given to reference model." - - outputs_torch = [] - for output in outputs: - output = output.astype(np.float32) - if self.is_quantized: - # Need to dequant back to FP32 for comparison with torch output - quant_param = self.qp_output - assert ( - quant_param is not None - ), "There are no quantization parameters, check output parameters" - output = (output - quant_param.zp) * quant_param.scale - outputs_torch.append(torch.from_numpy(output)) - return tuple(outputs_torch) + return [tosa_ref_output] def run_tosa_ref_model( self, @@ -450,13 +408,21 @@ def run_tosa_ref_model( assert ( shutil.which(self.tosa_ref_model_path) is not None ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}" - + loglevel_map = { + logging.INFO: "INFO", + logging.CRITICAL: "LOW", + logging.ERROR: "LOW", + logging.WARNING: "MED", + logging.DEBUG: "HIGH", + logging.NOTSET: "MED", + } + clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0) cmd_ref_model = [ self.tosa_ref_model_path, "--test_desc", desc_file_path, "-l", - _tosa_refmodel_loglevel(logger.level), + loglevel_map[clamped_logging_level], ] _run_cmd(cmd_ref_model) @@ -492,10 +458,7 @@ def run_tosa_ref_model( def prep_data_for_save( - data: torch.Tensor, - is_quantized: bool, - input_name: str, - quant_param: QuantizationParams, + data, is_quantized: bool, input_name: str, quant_param: QuantizationParams ): data_np = np.array(data.detach(), order="C").astype( f"{data.dtype}".replace("torch.", "") @@ -639,19 +602,3 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: pass return json_out - - -def _tosa_refmodel_loglevel(loglevel: int) -> str: - """Converts a logging loglevel to tosa_reference_model logginglevel, - returned as string. - """ - loglevel_map = { - logging.INFO: "INFO", - logging.CRITICAL: "LOW", - logging.ERROR: "LOW", - logging.WARNING: "MED", - logging.DEBUG: "HIGH", - logging.NOTSET: "MED", - } - clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0) - return loglevel_map[clamped_logging_level] diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 834e177b7d..096bc2b22f 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -39,7 +39,7 @@ from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir import EdgeCompileConfig from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule @@ -120,15 +120,10 @@ def __init__( super().__init__(dynamic_shapes) self.tosa_test_util = tosa_test_util - def run(self, artifact: EdgeProgramManager, inputs=None): - self.executorch_program = artifact.to_executorch(self.config) - if module := getattr( - artifact.exported_program().graph_module, "lowered_module_0", None - ): - self.buffer = module.processed_bytes - def run_artifact(self, inputs): - tosa_output = self.tosa_test_util.run_tosa_graph(self.buffer, inputs) + tosa_output = self.tosa_test_util.run_tosa_ref_model( + inputs=inputs, + ) return tosa_output @@ -321,7 +316,7 @@ def run_method_and_compare_outputs( logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}") reference_output = reference_stage.run_artifact(reference_input) - test_output = test_stage.run_artifact(test_input) + test_output = tuple(test_stage.run_artifact(test_input)) if ( is_nhwc and test_stage == self.stages[self.stage_name(tester.ToExecutorch)] diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 43f7d48b83..583237729d 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -88,7 +88,7 @@ ethos_u_base_rev="24.08" # tosa reference model tosa_reference_model_url="https://review.mlplatform.org/tosa/reference_model" -tosa_reference_model_rev="ef31e7222e99cb1c24b2aff9fc52b2d609612283" +tosa_reference_model_rev="f9ea4ab7da19318fe36b1c34d68a3e40fd6e56c5" ######## ### Mandatory user args @@ -227,13 +227,30 @@ function setup_tosa_reference_model() { cd reference_model git checkout ${tosa_reference_model_rev} git submodule update --init --recursive + cd .. + fi + cd reference_model + mkdir -p build + cd build + cmake .. + + # make use of half the cores for building + if [[ "${OS}" == "Linux" ]]; then + n=$(( $(nproc) / 2 )) + elif [[ "${OS}" == "Darwin" ]]; then + n=$(( $(sysctl -n hw.logicalcpu) / 2 )) + else + n=1 fi - echo "pip installing reference_model..." - repo_dir="${root_dir}/reference_model" - cd $repo_dir - pip install . + if [[ "$n" -lt 1 ]]; then + n=1 + fi + make -j"${n}" + cd reference_model + tosa_bin_path=`pwd` + echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}" } function setup_vela() { From b0f9a61767feac43d55da9f4ece3f0acb95cdeac Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Fri, 8 Nov 2024 09:27:32 -0800 Subject: [PATCH 33/59] Add support for uint16 in quant and dequant kernels Differential Revision: D65370235 Pull Request resolved: https://github.com/pytorch/executorch/pull/6724 --- backends/cadence/hifi/kernels/kernels.cpp | 4 ++++ backends/cadence/hifi/operators/dequantize_per_tensor.cpp | 3 +++ backends/cadence/hifi/operators/quantize_per_tensor.cpp | 4 ++++ backends/cadence/reference/kernels/kernels.cpp | 4 ++++ .../cadence/reference/operators/dequantize_per_tensor.cpp | 8 ++++++++ .../cadence/reference/operators/quantize_per_tensor.cpp | 8 ++++++++ 6 files changed, 31 insertions(+) diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index 10e5fb176e..1b335c846b 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -165,6 +165,7 @@ void requantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); #undef typed_quantize_val #define typed_quantize_vec(dtype) \ @@ -177,6 +178,7 @@ typed_quantize_val(int16_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -186,6 +188,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); #undef typed_dequantize_val #define typed_dequantize_vec(dtype) \ @@ -198,6 +201,7 @@ typed_dequantize_val(int16_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp index 18381a26e0..996d753c59 100644 --- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp @@ -41,6 +41,9 @@ void dequantize_per_tensor_out( } else if (input.scalar_type() == ScalarType::Short) { const int16_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); dequantize(out_data, input_data, scale, zero_point, numel); diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp index c65d62968f..1078b5716c 100644 --- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp @@ -44,6 +44,10 @@ void quantize_per_tensor_out( int16_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + cadence::impl::HiFi::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); cadence::impl::HiFi::kernels::quantize( diff --git a/backends/cadence/reference/kernels/kernels.cpp b/backends/cadence/reference/kernels/kernels.cpp index 4d4ff26c3f..faac3d7cb2 100644 --- a/backends/cadence/reference/kernels/kernels.cpp +++ b/backends/cadence/reference/kernels/kernels.cpp @@ -65,6 +65,7 @@ void dequantize( typed_quantize_val(int8_t); typed_quantize_val(uint8_t); typed_quantize_val(int16_t); +typed_quantize_val(uint16_t); typed_quantize_val(int32_t); #undef typed_quantize_val @@ -78,6 +79,7 @@ typed_quantize_val(int32_t); typed_quantize_vec(int8_t); typed_quantize_vec(uint8_t); typed_quantize_vec(int16_t); +typed_quantize_vec(uint16_t); typed_quantize_vec(int32_t); #undef typed_quantize_vec @@ -86,6 +88,7 @@ typed_quantize_vec(int32_t); typed_dequantize_val(int8_t); typed_dequantize_val(uint8_t); typed_dequantize_val(int16_t); +typed_dequantize_val(uint16_t); typed_dequantize_val(int32_t); #undef typed_dequantize_val @@ -99,6 +102,7 @@ typed_dequantize_val(int32_t); typed_dequantize_vec(int8_t); typed_dequantize_vec(uint8_t); typed_dequantize_vec(int16_t); +typed_dequantize_vec(uint16_t); typed_dequantize_vec(int32_t); #undef typed_dequantize_vec diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp index aef730bfd1..b49c045b94 100644 --- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp @@ -37,6 +37,14 @@ void dequantize_per_tensor_out( const int8_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Bits16) { + const uint16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + impl::reference::kernels::dequantize( + out_data, input_data, scale, zero_point, numel); } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); impl::reference::kernels::dequantize( diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp index 0d7ff0bc7e..ad5fa791b5 100644 --- a/backends/cadence/reference/operators/quantize_per_tensor.cpp +++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp @@ -39,6 +39,14 @@ void quantize_per_tensor_out( int8_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Bits16) { + uint16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + impl::reference::kernels::quantize( + out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); impl::reference::kernels::quantize( From 1cd8a06859470100d4e404f38b57e0ff2bde963a Mon Sep 17 00:00:00 2001 From: JP <46308822+zonglinpeng@users.noreply.github.com> Date: Fri, 8 Nov 2024 10:55:33 -0800 Subject: [PATCH 34/59] migrate passes and utils in cadence backend Differential Revision: D65447532 Pull Request resolved: https://github.com/pytorch/executorch/pull/6647 --- backends/cadence/aot/TARGETS | 15 +++++ backends/cadence/aot/pass_utils.py | 91 ++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 backends/cadence/aot/pass_utils.py diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 8456c50f6c..9876e59dbf 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -62,6 +62,21 @@ python_library( ], ) +python_library( + name = "pass_utils", + srcs = [ + "pass_utils.py", + ], + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//executorch/exir/passes:lib", + "//executorch/exir/passes:spec_prop_pass", + ], +) + python_library( name = "ops_registrations", srcs = [ diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py new file mode 100644 index 0000000000..3aa6f48a31 --- /dev/null +++ b/backends/cadence/aot/pass_utils.py @@ -0,0 +1,91 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from dataclasses import dataclass +from typing import Callable, Optional, Set, Union + +import torch +from executorch.backends.cadence.aot.utils import get_edge_overload_packet + +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket + +from executorch.exir.pass_base import ExportPass +from torch._ops import OpOverloadPacket + + +# Is an overlap in tensor lifetime and storage allowed at the current opt level? +# We allow overlap at opt level >= 2. +def allow_lifetime_and_storage_overlap(opt_level: int) -> bool: + return opt_level >= 2 + + +# A dataclass that stores the attributes of an ExportPass. +@dataclass +class CadencePassAttribute: + opt_level: Optional[int] = None + debug_pass: bool = False + + +# A dictionary that maps an ExportPass to its attributes. +_ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {} + + +def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute: + return _ALL_CADENCE_PASSES[p] + + +# A decorator that registers a pass. +def register_cadence_pass( + pass_attribute: CadencePassAttribute, +) -> Callable[[ExportPass], ExportPass]: + def wrapper(cls: ExportPass) -> ExportPass: + _ALL_CADENCE_PASSES[cls] = pass_attribute + return cls + + return wrapper + + +def get_all_available_cadence_passes() -> Set[ExportPass]: + return set(_ALL_CADENCE_PASSES.keys()) + + +# Create a new filter to filter out relevant passes from all Jarvis passes. +def create_cadence_pass_filter( + opt_level: int, debug: bool = False +) -> Callable[[ExportPass], bool]: + def _filter(p: ExportPass) -> bool: + pass_attribute = get_cadence_pass_attribute(p) + return ( + pass_attribute.opt_level is not None + and pass_attribute.opt_level <= opt_level + and (not pass_attribute.debug_pass or debug) + ) + + return _filter + + +# Return the overload packet for the edge or torch op. +def get_overload_packet( + op: Union[Callable[..., str], str], +) -> Union[OpOverloadPacket, EdgeOpOverloadPacket, None]: + return ( + get_edge_overload_packet(op) + if isinstance(op, EdgeOpOverload) + else getattr(op, "overloadpacket", None) + ) + + +# Get the list of node names in a graph module (only for "call_function" ops and +# EdgeOpOverload targets). This should be used only after to_edge is called. +def get_node_names_list_from_gm( + graph_module: torch.fx.GraphModule, +) -> list[torch.fx.Node]: + graph_nodes = [] + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if not isinstance(node.target, EdgeOpOverload): + continue + graph_nodes.append(node.name) + return graph_nodes From 6d6630edae0bdc37c8450e55f2fc0ef90cb5f50c Mon Sep 17 00:00:00 2001 From: JP <46308822+zonglinpeng@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:02:57 -0800 Subject: [PATCH 35/59] register quantized_linear.per_tensor in lib Differential Revision: D65104400 Pull Request resolved: https://github.com/pytorch/executorch/pull/6563 --- backends/cadence/aot/ops_registrations.py | 28 ++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index fce6ce5736..5e852b369d 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -50,7 +50,11 @@ "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( - "cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " + "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" ) lib.define( @@ -129,6 +133,28 @@ def quantized_linear_meta( return src.new_empty(out_size, dtype=src.dtype) +@register_fake("cadence::quantized_linear.per_tensor") +def quantized_linear_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: torch.SymInt, + weight_zero_point: torch.SymInt, + out_multiplier: torch.SymInt, + out_shift: torch.SymInt, + out_zero_point: torch.SymInt, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + @register_fake("cadence::quantized_conv") def quantized_conv_meta( input: torch.Tensor, From ddc8ea6f8daae685bcebbfc0c56736c239878184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Fri, 8 Nov 2024 20:05:59 +0100 Subject: [PATCH 36/59] Tosa specification handling (#6688) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add TOSA specification details to the Arm Backend. * Mandate the need for a TOSA version in the compile spec list passed to the Arm backend and propagate the information to node visitors for serialization handling. * Add TOSA version string to all TOSA tests * Adds handling of TOSA 0.80 BI and MI profile as separate serialization handlers for ADD as an example. Signed-off-by: Per Åstrand --- backends/arm/arm_backend.py | 31 ++- backends/arm/operators/node_visitor.py | 36 ++- backends/arm/operators/op_add.py | 73 +++++- backends/arm/operators/op_placeholder.py | 12 +- backends/arm/test/common.py | 10 +- backends/arm/test/misc/test_debug_feats.py | 16 +- .../arm/test/misc/test_dim_order_guards.py | 4 +- backends/arm/test/misc/test_lifted_tensor.py | 8 +- backends/arm/test/misc/test_tosa_spec.py | 105 ++++++++ .../arm/test/models/test_mobilenet_v2_arm.py | 8 +- backends/arm/test/ops/test_add.py | 4 +- backends/arm/test/ops/test_avg_pool.py | 8 +- backends/arm/test/ops/test_batch_norm.py | 6 +- backends/arm/test/ops/test_bmm.py | 4 +- backends/arm/test/ops/test_cat.py | 4 +- backends/arm/test/ops/test_clone.py | 4 +- backends/arm/test/ops/test_conv1d.py | 8 +- backends/arm/test/ops/test_conv2d.py | 8 +- backends/arm/test/ops/test_conv_combos.py | 8 +- backends/arm/test/ops/test_depthwise_conv.py | 8 +- backends/arm/test/ops/test_div.py | 4 +- backends/arm/test/ops/test_exp.py | 4 +- backends/arm/test/ops/test_expand.py | 4 +- backends/arm/test/ops/test_full.py | 4 +- backends/arm/test/ops/test_hardtanh.py | 4 +- backends/arm/test/ops/test_layer_norm.py | 8 +- backends/arm/test/ops/test_linear.py | 8 +- backends/arm/test/ops/test_log.py | 4 +- backends/arm/test/ops/test_logsoftmax.py | 4 +- backends/arm/test/ops/test_max_pool.py | 8 +- backends/arm/test/ops/test_mean_dim.py | 8 +- backends/arm/test/ops/test_mm.py | 4 +- backends/arm/test/ops/test_mul.py | 8 +- backends/arm/test/ops/test_permute.py | 4 +- backends/arm/test/ops/test_reciprocal.py | 4 +- backends/arm/test/ops/test_relu.py | 4 +- backends/arm/test/ops/test_repeat.py | 4 +- backends/arm/test/ops/test_rsqrt.py | 4 +- backends/arm/test/ops/test_scalars.py | 4 +- backends/arm/test/ops/test_select.py | 4 +- backends/arm/test/ops/test_sigmoid.py | 4 +- backends/arm/test/ops/test_slice.py | 4 +- backends/arm/test/ops/test_softmax.py | 4 +- backends/arm/test/ops/test_split.py | 4 +- backends/arm/test/ops/test_squeeze.py | 4 +- backends/arm/test/ops/test_sub.py | 4 +- backends/arm/test/ops/test_sum.py | 4 +- backends/arm/test/ops/test_tanh.py | 4 +- backends/arm/test/ops/test_unsqueeze.py | 4 +- backends/arm/test/ops/test_var.py | 4 +- backends/arm/test/ops/test_view.py | 4 +- .../passes/test_meandim_to_averagepool2d.py | 4 +- .../test/quantizer/test_generic_annotater.py | 4 +- backends/arm/tosa_specification.py | 226 ++++++++++++++++++ backends/arm/tosa_utils.py | 4 +- examples/arm/aot_arm_compiler.py | 4 +- 56 files changed, 618 insertions(+), 133 deletions(-) create mode 100644 backends/arm/test/misc/test_tosa_spec.py create mode 100644 backends/arm/tosa_specification.py diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 28af583106..b55f237543 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -20,6 +20,8 @@ from executorch.backends.arm.operators.node_visitor import get_node_visitors from executorch.backends.arm.operators.op_output import process_output from executorch.backends.arm.operators.op_placeholder import process_placeholder + +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm._passes.arm_pass_manager import ( ArmPassManager, ) # usort: skip @@ -86,9 +88,15 @@ def ethosu_compile_spec( if extra_flags is not None: self.compiler_flags.append(extra_flags) + base_tosa_version = "TOSA-0.80.0+BI" + if "U55" in config: + # Add the Ethos-U55 extension marker + base_tosa_version += "+u55" + self.tosa_version = TosaSpecification.create_from_string(base_tosa_version) + return self - def tosa_compile_spec(self) -> "ArmCompileSpecBuilder": + def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder": """ Generate compile spec for TOSA flatbuffer output """ @@ -96,6 +104,7 @@ def tosa_compile_spec(self) -> "ArmCompileSpecBuilder": self.output_format is None ), f"Output format already set: {self.output_format}" self.output_format = "tosa" + self.tosa_version = TosaSpecification.create_from_string(tosa_version) return self def dump_intermediate_artifacts_to( @@ -129,6 +138,13 @@ def build(self) -> List[CompileSpec]: """ Generate a list of compile spec objects from the builder """ + assert self.tosa_version + + # Always supply a TOSA version + self.compile_spec = [ + CompileSpec("tosa_version", str(self.tosa_version).encode()) + ] + if self.output_format == "vela": self.compile_spec += [ CompileSpec("output_format", "vela".encode()), @@ -210,11 +226,18 @@ def preprocess( # noqa: C901 if not output_format: raise RuntimeError("output format is required") + tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec) + assert ( + tosa_spec is not None + ), "TOSA backend needs a TOSA version specified in the CompileSpec!" + if output_format == "vela" and len(compile_flags) == 0: # Not testing for compile_flags correctness here, just that they are # present. The compiler will give errors if they are not valid. raise RuntimeError("compile flags are required for vela output format") + logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") + # Converted output for this subgraph, serializer needs path early as it emits # const data directly. Path created and data written only in debug builds. tosa_graph = ts.TosaSerializer(artifact_path) @@ -222,13 +245,13 @@ def preprocess( # noqa: C901 exported_program=edge_program, compile_spec=compile_spec ) - node_visitors = get_node_visitors(edge_program) + node_visitors = get_node_visitors(edge_program, tosa_spec) for node in graph_module.graph.nodes: if node.op == "call_function": - process_call_function(node, tosa_graph, node_visitors) + process_call_function(node, tosa_graph, node_visitors, tosa_spec) elif node.op == "placeholder": - process_placeholder(node, tosa_graph, edge_program) + process_placeholder(node, tosa_graph, edge_program, tosa_spec) elif node.op == "output": process_output(node, tosa_graph) else: diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 99fd0388e4..9e98ebcab9 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -1,4 +1,4 @@ -# Copyright 2023 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ import serializer.tosa_serializer as ts import torch from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from torch.export import ExportedProgram @@ -18,8 +19,19 @@ class NodeVisitor: Node Visitor pattern for lowering edge IR to TOSA """ - def __init__(self, exported_program: ExportedProgram): + # Add the currently supported node_visitor specs as default. + # This should be overriden in the NodeVisitor subclasses to target + # a specific TOSA version. + # When all node_visitors has been refactored to target a specific + # version, this list should be removed. + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification): self._exported_program = exported_program or None + self.tosa_spec = tosa_spec def define_node( self, @@ -33,16 +45,30 @@ def define_node( # container for all node visitors -_node_visitor_dict = {} +_node_visitor_dicts = { + TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {}, + TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {}, +} def register_node_visitor(visitor): - _node_visitor_dict[visitor.target] = visitor + for tosa_spec in visitor.tosa_specs: + _node_visitor_dicts[tosa_spec][visitor.target] = visitor + return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: node_visitors = {} - for target, visitor in _node_visitor_dict.items(): + tosa_spec = None + for arg in args: + if isinstance(arg, TosaSpecification): + tosa_spec = arg + break + + if tosa_spec is None: + raise RuntimeError("No TOSA specification supplied.") + + for target, visitor in _node_visitor_dicts[tosa_spec].items(): node_visitors[target] = visitor(*args) return node_visitors diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index ec2ade9e8a..d0518c7a5e 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -11,19 +11,25 @@ import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp from torch.fx import Node @register_node_visitor -class AddVisitor(NodeVisitor): +class AddVisitor_080_BI(NodeVisitor): target = "aten.add.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + ] + def __init__(self, *args): super().__init__(*args) @@ -35,9 +41,22 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_quant_node: - input_nodes = tutils.get_two_inputs(node) + input_nodes = tutils.get_two_inputs(node) + + if not is_quant_node and not all( + tensor.meta["val"].dtype in (torch.int8, torch.int32) + for tensor in input_nodes + ): + raise RuntimeError( + f"Unexpected non quantized {AddVisitor_080_BI.target} node." + ) + needs_rescale = not ( + all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes) + and node.meta["val"].dtype == torch.int32 + ) + + if needs_rescale: # Rescale inputs to 32 bit rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( input_nodes, tosa_graph @@ -48,20 +67,48 @@ def define_node( rescaled_inputs[0].shape, rescaled_inputs[0].shape ) add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + add_output = output + rescaled_inputs = inputs - # Do the INT32 Add - tosa_graph.addOperator( - TosaOp.Op().ADD, - [ - rescaled_inputs[0].name, - rescaled_inputs[1].name, - ], - [add_output.name], - None, - ) + # Do the INT32 Add + tosa_graph.addOperator( + TosaOp.Op().ADD, + [ + rescaled_inputs[0].name, + rescaled_inputs[1].name, + ], + [add_output.name], + None, + ) + if needs_rescale: # Scale output back to 8 bit tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph) + + +@register_node_visitor +class AddVisitor_080_MI(AddVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + if is_quant_node: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output, is_quant_node) else: # FP32 Add lowering tosa_graph.addOperator( diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 2618c9e71d..f3e52e68f7 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -14,6 +14,7 @@ get_quant_node_args, is_quant_arg, ) +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import ( is_bias_node_for_quantized_addmm, is_bias_node_for_quantized_conv, @@ -26,6 +27,7 @@ def process_inputs( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, + tosa_spec: TosaSpecification, ): """Serialize an input node""" # inputs need to be in default dim_order (contiguous memory format) @@ -95,6 +97,7 @@ def process_inputs_to_parameters( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, edge_program: ExportedProgram, + tosa_spec: TosaSpecification, ): """Serialize bias and non-quantized weights""" inputs = [TosaArg(node)] @@ -106,9 +109,13 @@ def process_inputs_to_parameters( if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node): # BI bias + assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer" process_quantized_bias(node, tosa_graph, parameter_values) else: # MI weights or bias + if inputs[0].dtype == torch.float32: + assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" + parameter_values = np.transpose(parameter_values, inputs[0].dim_order) tosa_graph.addConst( @@ -158,15 +165,16 @@ def process_placeholder( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, edge_program: ExportedProgram, + tosa_spec: TosaSpecification, ): """Wrapper for processing and serializing all types of placeholders""" assert node.name == node.target, "Expect placeholder name and target to match" assert 0 == len(node.args), "Can't handle default input values" if node.name in edge_program.graph_signature.user_inputs: - process_inputs(node, tosa_graph) + process_inputs(node, tosa_graph, tosa_spec) elif node.name in edge_program.graph_signature.inputs_to_parameters: - process_inputs_to_parameters(node, tosa_graph, edge_program) + process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec) elif node.name in edge_program.graph_signature.inputs_to_buffers: process_inputs_to_buffers(node, tosa_graph, edge_program) elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants: diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index b0e2a7f0bb..3a9818929b 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -177,16 +177,18 @@ def maybe_get_tosa_collate_path() -> str | None: def get_tosa_compile_spec( - permute_memory_to_nhwc=True, custom_path=None + tosa_version: str, permute_memory_to_nhwc=True, custom_path=None ) -> list[CompileSpec]: """ Default compile spec for TOSA tests. """ - return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build() + return get_tosa_compile_spec_unbuilt( + tosa_version, permute_memory_to_nhwc, custom_path + ).build() def get_tosa_compile_spec_unbuilt( - permute_memory_to_nhwc=False, custom_path=None + tosa_version: str, permute_memory_to_nhwc=False, custom_path=None ) -> ArmCompileSpecBuilder: """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify the compile spec before calling .build() to finalize it. @@ -202,7 +204,7 @@ def get_tosa_compile_spec_unbuilt( os.makedirs(intermediate_path, exist_ok=True) compile_spec_builder = ( ArmCompileSpecBuilder() - .tosa_compile_spec() + .tosa_compile_spec(tosa_version) .set_permute_memory_format(permute_memory_to_nhwc) .dump_intermediate_artifacts_to(intermediate_path) ) diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 7d9a18a80e..66e3e52d4d 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -49,7 +49,7 @@ def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -63,7 +63,7 @@ def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -107,7 +107,9 @@ def test_numerical_diff_prints(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=False + ), ) .export() .to_edge() @@ -132,7 +134,7 @@ def test_dump_ops_and_dtypes(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .dump_dtype_distribution() @@ -156,7 +158,7 @@ def test_dump_ops_and_dtypes_parseable(): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .dump_dtype_distribution(print_table=False) @@ -187,7 +189,7 @@ def test_collate_tosa_BI_tests(self): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -217,7 +219,7 @@ def test_dump_tosa_ops(caplog): ArmTester( model, example_inputs=model.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_dim_order_guards.py b/backends/arm/test/misc/test_dim_order_guards.py index 8bad1493b1..d7406afe95 100644 --- a/backends/arm/test/misc/test_dim_order_guards.py +++ b/backends/arm/test/misc/test_dim_order_guards.py @@ -34,7 +34,7 @@ def test_tosa_MI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -48,7 +48,7 @@ def test_tosa_BI_pipeline(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 29b2887431..12b8d0665b 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -60,7 +60,7 @@ def test_partition_lifted_tensor_tosa_MI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -77,7 +77,7 @@ def test_partition_lifted_tensor_tosa_BI(self, op, data): ArmTester( LiftedTensor(op), example_inputs=data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -95,7 +95,7 @@ def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -110,7 +110,7 @@ def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1): ArmTester( LiftedScalarTensor(op, arg1), example_inputs=(data), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py new file mode 100644 index 0000000000..5cbad140b7 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -0,0 +1,105 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + +test_valid_0_80_strings = [ + "TOSA-0.80.0+BI", + "TOSA-0.80.0+MI+8k", + "TOSA-0.80.0+BI+u55", +] +test_valid_1_00_strings = [ + "TOSA-1.00.0+INT+FP+fft", + "TOSA-1.00.0+FP+bf16+fft", + "TOSA-1.00.0+INT+int4+cf", + "TOSA-1.00.0+FP+cf+bf16+8k", + "TOSA-1.00.0+FP+INT+bf16+fft+int4+cf", + "TOSA-1.00.0+FP+INT+fft+int4+cf+8k", +] + +test_valid_1_00_extensions = { + "INT": ["int16", "int4", "var", "cf"], + "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"], +} + +test_invalid_strings = [ + "TOSA-0.80.0+bi", + "TOSA-0.80.0", + "TOSA-0.80.0+8k", + "TOSA-0.80.0+BI+MI", + "TOSA-0.80.0+BI+U55", + "TOSA-1.00.0+fft", + "TOSA-1.00.0+fp+bf16+fft", + "TOSA-1.00.0+INT+INT4+cf", + "TOSA-1.00.0+BI", + "TOSA-1.00.0+FP+FP+INT", + "TOSA-1.00.0+FP+CF+bf16", + "TOSA-1.00.0+BF16+fft+int4+cf+INT", +] + +test_compile_specs = [ + ([CompileSpec("tosa_version", "TOSA-0.80.0+BI".encode())],), + ([CompileSpec("tosa_version", "TOSA-0.80.0+BI+u55".encode())],), + ([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],), +] + +test_compile_specs_no_version = [ + ([CompileSpec("other_key", "TOSA-0.80.0+BI".encode())],), + ([CompileSpec("other_key", "some_value".encode())],), +] + + +class TestTosaSpecification(unittest.TestCase): + """Tests the TOSA specification class""" + + @parameterized.expand(test_valid_0_80_strings) + def test_version_string_0_80(self, version_string: str): + tosa_spec = TosaSpecification.create_from_string(version_string) + assert isinstance(tosa_spec, Tosa_0_80) + assert tosa_spec.profile in ["BI", "MI"] + + @parameterized.expand(test_valid_1_00_strings) + def test_version_string_1_00(self, version_string: str): + tosa_spec = TosaSpecification.create_from_string(version_string) + assert isinstance(tosa_spec, Tosa_1_00) + assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count( + True + ) > 0 + + for profile in tosa_spec.profiles: + assert [ + e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions + ] + + @parameterized.expand(test_invalid_strings) + def test_invalid_version_strings(self, version_string: str): + tosa_spec = None + with self.assertRaises(ValueError): + tosa_spec = TosaSpecification.create_from_string(version_string) + + assert tosa_spec is None + + @parameterized.expand(test_compile_specs) + def test_create_from_compilespec(self, compile_specs: list[CompileSpec]): + tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs) + assert isinstance(tosa_spec, TosaSpecification) + + @parameterized.expand(test_compile_specs_no_version) + def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]): + tosa_spec = None + with self.assertRaises(ValueError): + tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs) + + assert tosa_spec is None diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index a50e2732f1..97a802b15d 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -54,7 +54,9 @@ def test_mv2_tosa_MI(self): ArmTester( self.mv2, example_inputs=self.model_inputs, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge(config=self._edge_compile_config) @@ -69,7 +71,9 @@ def test_mv2_tosa_BI(self): ArmTester( self.mv2, example_inputs=self.model_inputs, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index e3eeb187da..66e278ee0f 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -61,7 +61,7 @@ def _test_add_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.add.Tensor": 1}) @@ -80,7 +80,7 @@ def _test_add_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index 344a80a79b..afd079fb95 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -55,7 +55,9 @@ def _test_avgpool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check(["torch.ops.aten.avg_pool2d.default"]) @@ -76,7 +78,9 @@ def _test_avgpool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index bfe1146a90..297ac0af1c 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -533,7 +533,7 @@ def _test_batchnorm2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -561,7 +561,7 @@ def _test_batchnorm2d_no_stats_tosa_MI_pipeline( ArmTester( module, example_example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten._native_batch_norm_legit.no_stats": 1}) @@ -590,7 +590,7 @@ def _test_batchnorm2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index e4e6abb7bb..e5e9508e25 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -50,7 +50,7 @@ def _test_bmm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.bmm.default": 1}) @@ -70,7 +70,7 @@ def _test_bmm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index 9723ba0f0c..b380c44d52 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -56,7 +56,7 @@ def _test_cat_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.cat.default": 1}) @@ -76,7 +76,7 @@ def _test_cat_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 9852c5c452..4721f257b0 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -47,7 +47,7 @@ def _test_clone_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.clone.default": 1}) @@ -66,7 +66,7 @@ def _test_clone_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py index 3b27554221..133148faef 100644 --- a/backends/arm/test/ops/test_conv1d.py +++ b/backends/arm/test/ops/test_conv1d.py @@ -226,7 +226,9 @@ def _test_conv1d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -246,7 +248,9 @@ def _test_conv1d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 46adfc8a01..43c3e85139 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -253,7 +253,9 @@ def _test_conv2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -273,7 +275,9 @@ def _test_conv2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 4b45b67126..3e9bdef958 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -192,7 +192,9 @@ def _test_conv_combo_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -214,7 +216,9 @@ def _test_conv_combo_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 01ffbc1054..4bfa863c49 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -177,7 +177,9 @@ def _test_dw_conv_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .to_edge() @@ -195,7 +197,9 @@ def _test_dw_conv_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index 84a8d53f9d..28cc686690 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -102,7 +102,7 @@ def _test_div_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.div.Tensor": 1}) @@ -121,7 +121,7 @@ def _test_div_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_exp.py b/backends/arm/test/ops/test_exp.py index 6e85d8fe49..c706b7b206 100644 --- a/backends/arm/test/ops/test_exp.py +++ b/backends/arm/test/ops/test_exp.py @@ -40,7 +40,7 @@ def _test_exp_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.exp.default"]) @@ -58,7 +58,7 @@ def _test_exp_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index aa13a6475c..effa7ce713 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -46,7 +46,7 @@ def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.expand.default": 1}) @@ -64,7 +64,7 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 2722edef32..d4cfc5c369 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -57,7 +57,7 @@ def _test_full_tosa_MI_pipeline( ArmTester( module, example_inputs=example_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.full.default": 1}) @@ -80,7 +80,7 @@ def _test_full_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .quantize() diff --git a/backends/arm/test/ops/test_hardtanh.py b/backends/arm/test/ops/test_hardtanh.py index c7c3736e37..a9f12abdf0 100644 --- a/backends/arm/test/ops/test_hardtanh.py +++ b/backends/arm/test/ops/test_hardtanh.py @@ -52,7 +52,7 @@ def _test_hardtanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.hardtanh.default"]) @@ -73,7 +73,7 @@ def _test_hardtanh_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 0150c20524..f059d71eba 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -74,7 +74,9 @@ def _test_layernorm_tosa_MI_pipeline( ArmTester( model=module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check(["torch.ops.aten.layer_norm.default"]) @@ -93,7 +95,9 @@ def _test_layernorm_tosa_BI_pipeline( ArmTester( model=module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .check_not(["torch.ops.aten.layer_norm.default"]) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 3f68ab0251..6221af8446 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -122,7 +122,9 @@ def _test_linear_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=False + ), ) .export() .check_count({"torch.ops.aten.linear.default": 1}) @@ -141,7 +143,9 @@ def _test_linear_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=False + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_log.py b/backends/arm/test/ops/test_log.py index 269b7be25f..847635ea36 100644 --- a/backends/arm/test/ops/test_log.py +++ b/backends/arm/test/ops/test_log.py @@ -40,7 +40,7 @@ def _test_log_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.log.default"]) @@ -58,7 +58,7 @@ def _test_log_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py index 2d51588bb3..5d84fa127f 100644 --- a/backends/arm/test/ops/test_logsoftmax.py +++ b/backends/arm/test/ops/test_logsoftmax.py @@ -46,7 +46,7 @@ def _test_logsoftmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.log_softmax.int"]) @@ -66,7 +66,7 @@ def _test_logsoftmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 5c48afa3ce..41526b1c77 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -62,7 +62,9 @@ def _test_maxpool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check(["torch.ops.aten.max_pool2d.default"]) @@ -87,7 +89,9 @@ def _test_maxpool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 68307bcdf1..e8320cf1df 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -81,7 +81,7 @@ def _test_adaptive_avg_pool2d_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.adaptive_avg_pool2d.default"]) @@ -101,7 +101,7 @@ def _test_adaptive_avg_pool2d_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -150,7 +150,7 @@ def _test_meandim_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_not(["torch.ops.quantized_decomposed"]) @@ -169,7 +169,7 @@ def _test_meandim_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py index 4271496eaa..21b02bbd10 100644 --- a/backends/arm/test/ops/test_mm.py +++ b/backends/arm/test/ops/test_mm.py @@ -54,7 +54,7 @@ def _test_mm_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.mm.default": 1}) @@ -74,7 +74,7 @@ def _test_mm_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index a1c2dba5fe..7fa20c2566 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -70,7 +70,9 @@ def _test_mul_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True + ), ) .export() .check_count({"torch.ops.aten.mul.Tensor": 1}) @@ -89,7 +91,9 @@ def _test_mul_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True), + compile_spec=common.get_tosa_compile_spec( + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True + ), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index 6346e847c9..62b6b823de 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -57,7 +57,7 @@ def _test_permute_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute_memory_to_nhwc + "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute_memory_to_nhwc ), ) .export() @@ -79,7 +79,7 @@ def _test_permute_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_reciprocal.py b/backends/arm/test/ops/test_reciprocal.py index cb4971bf8c..7745a614e6 100644 --- a/backends/arm/test/ops/test_reciprocal.py +++ b/backends/arm/test/ops/test_reciprocal.py @@ -46,7 +46,7 @@ def _test_reciprocal_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.reciprocal.default": 1}) @@ -65,7 +65,7 @@ def _test_reciprocal_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_relu.py b/backends/arm/test/ops/test_relu.py index effbccc74d..595c907b32 100644 --- a/backends/arm/test/ops/test_relu.py +++ b/backends/arm/test/ops/test_relu.py @@ -48,7 +48,7 @@ def _test_relu_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.relu.default"]) @@ -69,7 +69,7 @@ def _test_relu_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index 1efac9f974..20c57ba749 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -47,7 +47,7 @@ def _test_repeat_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.repeat.default": 1}) @@ -65,7 +65,7 @@ def _test_repeat_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py index 2ccb7ec991..2cddc8da26 100644 --- a/backends/arm/test/ops/test_rsqrt.py +++ b/backends/arm/test/ops/test_rsqrt.py @@ -35,7 +35,7 @@ def _test_rsqrt_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.rsqrt.default": 1}) @@ -53,7 +53,7 @@ def _test_rsqrt_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index 0305ef58c0..86433745a6 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -123,7 +123,7 @@ def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -137,7 +137,7 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index fdb2fa1463..85bfc15d2d 100644 --- a/backends/arm/test/ops/test_select.py +++ b/backends/arm/test/ops/test_select.py @@ -58,7 +58,7 @@ def _test_select_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+MI", permute_memory_to_nhwc=permute ), ) .export() @@ -84,7 +84,7 @@ def _test_select_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute ), ) .quantize() diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py index 4d126b68e5..f12658c985 100644 --- a/backends/arm/test/ops/test_sigmoid.py +++ b/backends/arm/test/ops/test_sigmoid.py @@ -71,7 +71,7 @@ def _test_sigmoid_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.sigmoid.default"]) @@ -89,7 +89,7 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 18db358fdf..0fc92b011a 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -39,7 +39,7 @@ def _test_slice_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.slice.Tensor"]) @@ -60,7 +60,7 @@ def _test_slice_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - permute_memory_to_nhwc=permute + "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute ), ) .quantize() diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index 954dd201a9..f883d6b8de 100644 --- a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -47,7 +47,7 @@ def _test_softmax_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.softmax.int"]) @@ -67,7 +67,7 @@ def _test_softmax_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_split.py b/backends/arm/test/ops/test_split.py index 8ed0e723f1..42395c4c2d 100644 --- a/backends/arm/test/ops/test_split.py +++ b/backends/arm/test/ops/test_split.py @@ -56,7 +56,7 @@ def _test_split_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -79,7 +79,7 @@ def _test_split_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_squeeze.py b/backends/arm/test/ops/test_squeeze.py index c3f1edf37b..7e915da645 100644 --- a/backends/arm/test/ops/test_squeeze.py +++ b/backends/arm/test/ops/test_squeeze.py @@ -61,7 +61,7 @@ def _test_squeeze_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({export_target: 1}) @@ -82,7 +82,7 @@ def _test_squeeze_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index e80c043698..5c67240e52 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -43,7 +43,7 @@ def _test_sub_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.sub.Tensor": 1}) @@ -63,7 +63,7 @@ def _test_sub_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 73860dfa4a..9cd63b0a22 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -49,7 +49,7 @@ def _test_sum_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.sum.dim_IntList": 1}) @@ -68,7 +68,7 @@ def _test_sum_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py index 6f5cf17cf3..5f3859eadd 100644 --- a/backends/arm/test/ops/test_tanh.py +++ b/backends/arm/test/ops/test_tanh.py @@ -44,7 +44,7 @@ def _test_tanh_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check(["torch.ops.aten.tanh.default"]) @@ -62,7 +62,7 @@ def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple) ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py index 36bb93b796..8936d55f8b 100644 --- a/backends/arm/test/ops/test_unsqueeze.py +++ b/backends/arm/test/ops/test_unsqueeze.py @@ -35,7 +35,7 @@ def _test_unsqueeze_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.unsqueeze.default": 1}) @@ -53,7 +53,7 @@ def _test_unsqueeze_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 56b7c5fbb4..3a1285e6da 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -86,7 +86,7 @@ def _test_var_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .to_edge() @@ -107,7 +107,7 @@ def _test_var_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index 54e80702e3..09a8f57bd3 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -55,7 +55,7 @@ def _test_view_tosa_MI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() .check_count({"torch.ops.aten.view.default": 1}) @@ -73,7 +73,7 @@ def _test_view_tosa_BI_pipeline( ArmTester( module, example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py index c8fa0f4b7a..615187fb65 100644 --- a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -46,7 +46,7 @@ def test_tosa_BI_meandim_to_averagepool(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() @@ -63,7 +63,7 @@ def test_tosa_BI_meandim_no_modification(self): ArmTester( module, example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) .quantize() .export() diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index b859757df4..3d39463a42 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -30,7 +30,9 @@ def example_inputs(self): class TestGenericAnnotator(unittest.TestCase): def check_annotation(self, model): tester = ArmTester( - model, model.example_inputs(), common.get_tosa_compile_spec() + model, + model.example_inputs(), + common.get_tosa_compile_spec("TOSA-0.80.0+BI"), ) quant_model = tester.quantize().get_artifact() partitions = get_source_partitions(quant_model.graph, [model.op]) diff --git a/backends/arm/tosa_specification.py b/backends/arm/tosa_specification.py new file mode 100644 index 0000000000..716e8daee2 --- /dev/null +++ b/backends/arm/tosa_specification.py @@ -0,0 +1,226 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# +# Main implementation of AoT flow to partition and preprocess for Arm target +# backends. Converts via TOSA as an intermediate form supported by AoT and +# JIT compiler flows. +# + +import re +from typing import List + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from packaging.version import Version + + +class TosaSpecification: + """ + This class implements a representation of TOSA specification + (https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile + (with extension) and a level (8k). + For 0.80 releases the profile is BI or MI, with u55 handled as an inofficial extension + For 1.00 releases the profile is INT or FP, and the extensions are for + INT: int16, int4, var, cf + FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf + + The TOSA specification is encoded in the string represenatation + TOSA-major.minor.patch+profile[+level][+extensions] + + For 0.80 MI implies BI, while for 1.0 the profiles has to explicitely be specified. + + Profiles are uppercase letters and extensions and level is lowercase. + """ + + version: Version + + def support_integer(self) -> bool: + """ + Returns true if any integer operations are supported for the specification. + """ + raise NotImplementedError + + def support_float(self) -> bool: + """ + Returns true if any float operations are supported for the specification. + """ + raise NotImplementedError + + def __init__(self, version: Version): + self.version = version + + @staticmethod + def create_from_compilespecs( + compile_specs: List[CompileSpec], + ) -> "TosaSpecification": + """ + Search the CompileSpec list for 'tosa_version' and instantiate a + class from the found value or return None on failure. + """ + for spec in compile_specs: + if spec.key == "tosa_version": + return TosaSpecification.create_from_string(spec.value.decode()) + raise ValueError( + "No TOSA version key found in any of the supplied CompileSpecs" + ) + + @staticmethod + def create_from_string(repr: str) -> "TosaSpecification": + """ + Creates a TOSA specification class from a string representation: + TOSA-0.80.0+MI + TOSA-0.80.0+BI+8k + TOSA-0.80.0+BI+u55 # Ethos-U55 extension to handle TOSA subset + TOSA-0.90.0+MI + TOSA-1.00.0+INT+FP+int4+cf + """ + + pattern = r"^(TOSA)-([\d.]+)\+(.+)$" + match = re.match(pattern, repr) + if match: + name = match.group(1) + version = Version(match.group(2)) + extras = match.group(3).split("+") + if name != "TOSA": + raise ValueError(f"Malformed TOSA specification representation: {repr}") + match version: + case _ if version.major == 0 and version.minor == 80: + return Tosa_0_80(version, extras) + case _ if version.major == 1 and version.minor == 0: + return Tosa_1_00(version, extras) + case _: + raise ValueError(f"Wrong TOSA version: {version} from {repr}") + + raise ValueError(f"Failed to parse TOSA specification representation: {repr}") + + +class Tosa_0_80(TosaSpecification): + profile: str + level_8k: bool + is_U55_subset: bool + available_profiles = ["BI", "MI"] # MT is not defined + + def __init__(self, version: Version, extras: List[str]): + super().__init__(version) + assert version >= Version("0.80") and version < Version("0.90") + + # Check that we only have one profile in the extensions list + if [e in Tosa_0_80.available_profiles for e in extras].count(True) != 1: + raise ValueError( + f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found." + ) + + # The list contains one profile at most, so pick it + self.profile = [e for e in extras if e in Tosa_0_80.available_profiles][0] + extras.remove(self.profile) + + self.level_8k = "8k" in extras + if self.level_8k: + extras.remove("8k") + self.is_U55_subset = "u55" in extras + if self.is_U55_subset: + extras.remove("u55") + + if len(extras) > 0: + raise ValueError(f"Unhandled extras found: {extras}") + + def __repr__(self): + extensions = "" + if self.level_8k: + extensions += "+8K" + if self.is_U55_subset: + extensions += "+u55" + return f"TOSA-{str(self.version)}+{self.profile}{extensions}" + + def __hash__(self) -> int: + return hash(str(self.version) + self.profile) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Tosa_0_80): + return (self.version == other.version) and (self.profile == other.profile) + return False + + def support_integer(self): + return True + + def support_float(self): + return self.profile == "MI" + + +class Tosa_1_00(TosaSpecification): + profiles: List[str] + level_8k: bool + extensions: List[str] + + available_profiles = ["INT", "FP"] + valid_extensions = { + "INT": ["int16", "int4", "var", "cf"], + "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"], + } + + def __init__(self, version: Version, extras: List[str]): + super().__init__(version) + + # Check that we have at least one profile in the extensions list + if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0: + raise ValueError( + f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}." + ) + + # and not more than number of available profiles + if [e in Tosa_1_00.available_profiles for e in extras].count(True) > len( + Tosa_1_00.available_profiles + ): + raise ValueError( + f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}." + ) + + # The list contains one profile at least, so pick them + self.profiles = [e for e in extras if e in Tosa_1_00.available_profiles] + for p in self.profiles: + extras.remove(p) + + self.level_8k = "8k" in extras + if self.level_8k: + extras.remove("8k") + + combined_extensions = [] + for p in self.profiles: + combined_extensions += Tosa_1_00.valid_extensions[p] + + if not all(e in combined_extensions for e in extras): + raise ValueError( + f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {extras}" + ) + + # all the rest of the extras are handled extensions + self.extensions = extras + + def _get_profiles_string(self) -> str: + return "".join(["+" + p for p in self.profiles]) + + def _get_extensions_string(self) -> str: + return "".join(["+" + e for e in self.extensions]) + + def __repr__(self): + return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}" + + def __hash__(self) -> int: + return hash(str(self.version) + self._get_profiles_string()) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Tosa_1_00): + return (self.version == other.version) and ( + self._get_profiles_string() == other._get_profiles_string() + ) + return False + + def support_integer(self): + return "INT" in self.profiles + + def support_float(self): + return "FP" in self.profiles diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index cfafac1676..bf60aaf0f8 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -21,6 +21,7 @@ is_quant_node, q_op, ) +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -290,6 +291,7 @@ def process_call_function( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, node_visitors: Dict[str, NodeVisitor], + tosa_spec: TosaSpecification, ): # Unpack arguments and convert inputs = getNodeArgs(node) @@ -319,7 +321,7 @@ def process_call_function( is_quant_node(node), ) else: - raise RuntimeError(f"Unknown operator {node.target}") + raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") def expand_dims( diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 3075d992d5..e718c52fdc 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -180,7 +180,9 @@ def get_compile_spec( spec_builder = None if target == "TOSA": spec_builder = ( - ArmCompileSpecBuilder().tosa_compile_spec().set_permute_memory_format(True) + ArmCompileSpecBuilder() + .tosa_compile_spec("TOSA-0.80.0+BI") + .set_permute_memory_format(True) ) elif "ethos-u55" in target: spec_builder = ( From b1e6617f37747f7f5322ccb473b8983ea598a675 Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:27:58 -0800 Subject: [PATCH 37/59] Fix pyre Differential Revision: D65684570 Pull Request resolved: https://github.com/pytorch/executorch/pull/6740 --- backends/arm/_passes/TARGETS | 1 + backends/arm/operators/op_max_pool2d.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index ca20b03fcc..6ca59cfee2 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -7,6 +7,7 @@ python_library( deps = [ "//executorch/backends/arm:tosa_quant_utils", "//executorch/backends/arm:tosa_utils", + "//executorch/backends/xnnpack/_passes:xnnpack_passes", "//executorch/exir:lib", ], ) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 0752d8242f..a0b868f684 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import cast, List +from typing import List import serializer.tosa_serializer as ts import torch @@ -54,9 +54,7 @@ def define_node( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args( - cast(torch.fx.Node, node.all_input_nodes[0]) - ).zp + input_zp = get_quant_node_args(node.all_input_nodes[0]).zp output_zp = get_quant_node_args(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() From 289e84edde3e17c6276a3c57c07daf09a88f8bc0 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 8 Nov 2024 22:28:28 -0800 Subject: [PATCH 38/59] Correctly set _GLIBCXX_USE_CXX11_ABI pybind compile options (#6744) * Update torchao pinned commit to latest * Correctly set _GLIBCXX_USE_CXX11_ABI pybind compile options * Remove unwanted change * Update CMakeLists.txt Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> * Update CMakeLists.txt to remove the empty else branch --------- Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- CMakeLists.txt | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 156fb24e6b..6b76f27eb0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -721,10 +721,15 @@ if(EXECUTORCH_BUILD_PYBIND) -fPIC -frtti -fexceptions - # libtorch is built with the old ABI, so we need to do the same for any - # .cpp files that include torch, c10, or ATen targets. - -D_GLIBCXX_USE_CXX11_ABI=0 ) + if(EXECUTORCH_DO_NOT_USE_CXX11_ABI) + # libtorch is built with the old ABI, so we need to do the same for any + # .cpp files that include torch, c10, or ATen targets. Note that PyTorch + # nightly binary is built with _GLIBCXX_USE_CXX11_ABI set to 0 while its + # CI build sets this to 1 (default) + list(APPEND _pybind_compile_options -D_GLIBCXX_USE_CXX11_ABI=0) + endif() + # util lib add_library( util ${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util/print_evalue.cpp From 427b36d09a0e367b5a54876daecfd5cd78fb1e43 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Sat, 9 Nov 2024 18:14:54 -0800 Subject: [PATCH 39/59] Add Android standalone log target Differential Revision: D65268257 Pull Request resolved: https://github.com/pytorch/executorch/pull/6590 --- extension/android/CMakeLists.txt | 4 +- extension/android/jni/BUCK | 23 +++++- extension/android/jni/jni_layer.cpp | 112 ++++++---------------------- extension/android/jni/log.cpp | 69 +++++++++++++++++ extension/android/jni/log.h | 43 +++++++++++ 5 files changed, 156 insertions(+), 95 deletions(-) create mode 100644 extension/android/jni/log.cpp create mode 100644 extension/android/jni/log.h diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index c96cfeb5d7..70f21f2751 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch) find_package(executorch CONFIG REQUIRED) target_link_options_shared_lib(executorch) -add_library(executorch_jni SHARED jni/jni_layer.cpp) +add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp) set(link_libraries) list( @@ -146,7 +146,7 @@ if(EXECUTORCH_JNI_CUSTOM_LIBRARY) endif() if(EXECUTORCH_BUILD_LLAMA_JNI) - target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp) + target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) list(APPEND link_libraries llama_runner llava_runner) target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1) add_subdirectory( diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index 6f269739c0..e1bf26fef2 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -1,5 +1,6 @@ load("@fbsource//tools/build_defs/android:fb_android_cxx_library.bzl", "fb_android_cxx_library") load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib") oncall("executorch") @@ -25,7 +26,7 @@ executorch_generated_lib( fb_android_cxx_library( name = "executorch_jni", - srcs = ["jni_layer.cpp"], + srcs = ["jni_layer.cpp", "log.cpp"], headers = ["jni_layer_constants.h"], allow_jni_merging = False, compiler_flags = [ @@ -36,6 +37,7 @@ fb_android_cxx_library( soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ + ":log_provider_static", "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", "//third-party/glog:glog", @@ -49,7 +51,7 @@ fb_android_cxx_library( fb_android_cxx_library( name = "executorch_jni_full", - srcs = ["jni_layer.cpp"], + srcs = ["jni_layer.cpp", "log.cpp"], headers = ["jni_layer_constants.h"], allow_jni_merging = False, compiler_flags = [ @@ -60,6 +62,7 @@ fb_android_cxx_library( soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ + ":log_provider_static", ":generated_op_lib_optimized_static", "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", @@ -88,6 +91,7 @@ fb_android_cxx_library( soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ + ":log_provider_static", "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", "//third-party/glog:glog", @@ -101,3 +105,18 @@ fb_android_cxx_library( "//xplat/executorch/extension/threadpool:threadpool_static", ], ) + +runtime.cxx_library( + name = "log_provider", + srcs = ["log.cpp"], + exported_headers = ["log.h"], + compiler_flags = [ + "-frtti", + "-fexceptions", + "-Wno-unused-variable", + ], + deps = [ + "//executorch/runtime/core:core", + ], + visibility = ["@EXECUTORCH_CLIENTS"], +) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 479da28806..ddba8462b9 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -17,6 +17,7 @@ #include "jni_layer_constants.h" +#include #include #include #include @@ -36,76 +37,6 @@ using namespace executorch::extension; using namespace torch::executor; -#ifdef __ANDROID__ -#include -#include -#include - -// Number of entries to store in the in-memory log buffer. -const size_t log_buffer_length = 16; - -struct log_entry { - et_timestamp_t timestamp; - et_pal_log_level_t level; - std::string filename; - std::string function; - size_t line; - std::string message; - - log_entry( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) - : timestamp(timestamp), - level(level), - filename(filename), - function(function), - line(line), - message(message, length) {} -}; - -namespace { -std::vector log_buffer_; -std::mutex log_buffer_mutex_; -} // namespace - -// For Android, write to logcat -void et_pal_emit_log_message( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) { - std::lock_guard guard(log_buffer_mutex_); - - while (log_buffer_.size() >= log_buffer_length) { - log_buffer_.erase(log_buffer_.begin()); - } - - log_buffer_.emplace_back( - timestamp, level, filename, function, line, message, length); - - int android_log_level = ANDROID_LOG_UNKNOWN; - if (level == 'D') { - android_log_level = ANDROID_LOG_DEBUG; - } else if (level == 'I') { - android_log_level = ANDROID_LOG_INFO; - } else if (level == 'E') { - android_log_level = ANDROID_LOG_ERROR; - } else if (level == 'F') { - android_log_level = ANDROID_LOG_FATAL; - } - - __android_log_print(android_log_level, "ExecuTorch", "%s", message); -} -#endif - namespace executorch::extension { class TensorHybrid : public facebook::jni::HybridClass { public: @@ -437,24 +368,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref> readLogBuffer() { #ifdef __ANDROID__ - std::lock_guard guard(log_buffer_mutex_); - - const auto size = log_buffer_.size(); - facebook::jni::local_ref> ret = - facebook::jni::JArrayClass::newArray(size); - - for (auto i = 0u; i < size; i++) { - const auto& entry = log_buffer_[i]; - // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL MESSAGE". - std::stringstream ss; - ss << "[" << entry.timestamp << " " << entry.function << " " - << entry.filename << ":" << entry.line << "] " - << static_cast(entry.level) << " " << entry.message; - - facebook::jni::local_ref jstr_message = - facebook::jni::make_jstring(ss.str().c_str()); - (*ret)[i] = jstr_message; - } + + facebook::jni::local_ref> ret; + + access_log_buffer([&](std::vector& buffer) { + const auto size = buffer.size(); + ret = facebook::jni::JArrayClass::newArray(size); + for (auto i = 0u; i < size; i++) { + const auto& entry = buffer[i]; + // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL + // MESSAGE". + std::stringstream ss; + ss << "[" << entry.timestamp << " " << entry.function << " " + << entry.filename << ":" << entry.line << "] " + << static_cast(entry.level) << " " << entry.message; + + facebook::jni::local_ref jstr_message = + facebook::jni::make_jstring(ss.str().c_str()); + (*ret)[i] = jstr_message; + } + }); return ret; #else @@ -468,10 +401,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("forward", ExecuTorchJni::forward), makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), - -#ifdef __ANDROID__ makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer), -#endif }); } }; diff --git a/extension/android/jni/log.cpp b/extension/android/jni/log.cpp new file mode 100644 index 0000000000..663198e127 --- /dev/null +++ b/extension/android/jni/log.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "log.h" + +#ifdef __ANDROID__ + +#include +#include +#include +#include + +using executorch::extension::log_entry; + +// Number of entries to store in the in-memory log buffer. +const size_t log_buffer_length = 16; + +namespace { +std::vector log_buffer_; +std::mutex log_buffer_mutex_; +} // namespace + +// For Android, write to logcat +void et_pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) { + std::lock_guard guard(log_buffer_mutex_); + + while (log_buffer_.size() >= log_buffer_length) { + log_buffer_.erase(log_buffer_.begin()); + } + + log_buffer_.emplace_back( + timestamp, level, filename, function, line, message, length); + + int android_log_level = ANDROID_LOG_UNKNOWN; + if (level == 'D') { + android_log_level = ANDROID_LOG_DEBUG; + } else if (level == 'I') { + android_log_level = ANDROID_LOG_INFO; + } else if (level == 'E') { + android_log_level = ANDROID_LOG_ERROR; + } else if (level == 'F') { + android_log_level = ANDROID_LOG_FATAL; + } + + __android_log_print(android_log_level, "ExecuTorch", "%s", message); +} + +namespace executorch::extension { + +void access_log_buffer(std::function&)> accessor) { + std::lock_guard guard(log_buffer_mutex_); + accessor(log_buffer_); +} + +} // namespace executorch::extension + +#endif diff --git a/extension/android/jni/log.h b/extension/android/jni/log.h new file mode 100644 index 0000000000..4389b1d61a --- /dev/null +++ b/extension/android/jni/log.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace executorch::extension { +struct log_entry { + et_timestamp_t timestamp; + et_pal_log_level_t level; + std::string filename; + std::string function; + size_t line; + std::string message; + + log_entry( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) + : timestamp(timestamp), + level(level), + filename(filename), + function(function), + line(line), + message(message, length) {} +}; + +void access_log_buffer(std::function&)> accessor); +} // namespace executorch::extension From 5b51bb8b676ee79a1a0aeb52869a71c6fad6a291 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Sat, 9 Nov 2024 18:25:50 -0800 Subject: [PATCH 40/59] Support sym round and ceil Differential Revision: D65382714 Pull Request resolved: https://github.com/pytorch/executorch/pull/6699 --- exir/pass_base.py | 6 ++- exir/passes/__init__.py | 2 +- exir/passes/executorch_prim_ops_registry.py | 17 +++++++- kernels/prim_ops/register_prim_ops.cpp | 45 +++++++++++++++++++++ kernels/prim_ops/test/prim_ops_test.cpp | 41 +++++++++++++++++++ 5 files changed, 107 insertions(+), 4 deletions(-) diff --git a/exir/pass_base.py b/exir/pass_base.py index db6bef8e3f..9c97921f51 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -318,7 +318,11 @@ def call_function( if target == operator.getitem: value, key = args return self.callback.call_getitem(value, key, meta) - elif getattr(target, "__module__", None) in {"_operator", "math"}: + elif getattr(target, "__module__", None) in { + "_operator", + "builtins", + "math", + }: assert callable(target) return self.callback.call_sym(target, args, meta) elif target in _TORCH_SYM_OPS: diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 7a0623040f..fdb954010c 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -339,7 +339,7 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule: self.call(get_submodule(node.args[0])) self.call(get_submodule(node.args[1])) continue - elif getattr(target, "__module__", None) == "_operator": + elif getattr(target, "__module__", None) in ("builtins", "_operator"): continue elif target in to_out_var_skiplist: continue diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index 4af233aaa6..fa1c2e6913 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import builtins import math import operator -from typing import Dict, Set, Union +from typing import Any, Dict, Set, Union # necessary to ensure the ops are registered import torch @@ -94,12 +95,24 @@ def neg(a: _SymScalar) -> _SymScalar: return -a # pyre-ignore +@bind_pattern_to_op(executorch_prims_lib, "ceil.Scalar(Scalar a) -> Scalar") +def ceil(a: _SymScalar) -> _SymScalar: + return math.ceil(a) # pyre-ignore + + +@bind_pattern_to_op(executorch_prims_lib, "round.Scalar(Scalar a) -> Scalar") +def builtin_round(a: _SymScalar) -> _SymScalar: + return round(a) # pyre-ignore + + @bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar") def trunc(a: _SymScalar) -> _SymScalar: return math.trunc(a) # pyre-ignore -_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = { +_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = { + builtins.round: ops.backend.executorch_prim.round.Scalar, + math.ceil: ops.backend.executorch_prim.ceil.Scalar, math.trunc: ops.backend.executorch_prim.trunc.Scalar, operator.sub: ops.backend.executorch_prim.sub.Scalar, operator.mul: ops.backend.executorch_prim.mul.Scalar, diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 5755ab8d66..38901bb840 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -303,6 +303,51 @@ static Kernel prim_ops[] = { } }), + // ceil.Scalar(Scalar a) -> Scalar + Kernel( + "executorch_prim::ceil.Scalar", + [](KernelRuntimeContext& context, EValue** stack) { + (void)context; + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isDouble()) { + out = EValue(static_cast(ceil(a.toDouble()))); + } else { + ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag); + } + }), + + // round.Scalar(Scalar a) -> Scalar + Kernel( + "executorch_prim::round.Scalar", + [](KernelRuntimeContext& context, EValue** stack) { + (void)context; + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isDouble()) { + // Round half to even to match Python round(). Need an explicit + // implementation as not all platforms support fenv rounding modes. + // See + // https://codeyarns.com/tech/2018-08-17-how-to-round-half-to-even.html + const auto val = a.toDouble(); + const auto r = round(val); + const auto d = r - val; + auto res = 0.0; + + if (std::abs(d) != 0.5) { + res = r; + } else if (fmod(r, 2.0) == 0.0) { + res = r; + } else { + res = val - d; + } + + out = EValue(static_cast(res)); + } else { + ET_CHECK_MSG(false, "Unsupported DType %zu", (size_t)a.tag); + } + }), + // trunc.Scalar(Scalar a) -> Scalar Kernel( "executorch_prim::trunc.Scalar", diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index 3581a470da..ab6bd28e6c 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -503,6 +503,47 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) { getOpsFn("executorch_prim::et_view.default")(context, bad_stack), ""); } +TEST_F(RegisterPrimOpsTest, TestCeil) { + std::array inputs = { + 0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999}; + std::array expected = {0, 1, 1, 1, 1, 2, 0, -1, -1, 10}; + + for (auto i = 0; i < inputs.size(); i++) { + EValue values[2]; + values[0] = EValue(inputs[i]); + values[1] = EValue(0.0); + + EValue* stack[2]; + for (size_t j = 0; j < 2; j++) { + stack[j] = &values[j]; + } + + getOpsFn("executorch_prim::ceil.Scalar")(context, stack); + EXPECT_EQ(stack[1]->toInt(), expected[i]); + } +} + +TEST_F(RegisterPrimOpsTest, TestRound) { + // Note that Python uses round-to-even for halfway values. + std::array inputs = { + 0.0, 0.25, 0.5, 0.75, 1.0, 1.5, -0.5, -1.0, -1.5, 9.999999}; + std::array expected = {0, 0, 0, 1, 1, 2, 0, -1, -2, 10}; + + for (auto i = 0; i < inputs.size(); i++) { + EValue values[2]; + values[0] = EValue(inputs[i]); + values[1] = EValue(0.0); + + EValue* stack[2]; + for (size_t j = 0; j < 2; j++) { + stack[j] = &values[j]; + } + + getOpsFn("executorch_prim::round.Scalar")(context, stack); + EXPECT_EQ(stack[1]->toInt(), expected[i]); + } +} + TEST_F(RegisterPrimOpsTest, TestTrunc) { std::array inputs = { 0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999}; From 0a20d7205c9247b947dbef700fb6dcdbc4e640cb Mon Sep 17 00:00:00 2001 From: Oscar Andersson <87121123+oscarandersson8218@users.noreply.github.com> Date: Mon, 11 Nov 2024 10:27:25 +0100 Subject: [PATCH 41/59] Arm backend: Add linear decomposition (#6661) Add linear decomposition - Linear is decomposed to conv2d by Arm backend. - Enable nn.Linear(..., bias=False). - Remove op_addmm and related helper functions. - Remove unused helper functions. Signed-off-by: Oscar Andersson --- backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/decompose_linear_pass.py | 112 +++++++++++++ backends/arm/arm_partitioner.py | 13 +- backends/arm/operators/__init__.py | 1 - backends/arm/operators/op_add.py | 6 +- backends/arm/operators/op_addmm.py | 148 ------------------ backends/arm/operators/op_permute.py | 8 - backends/arm/operators/op_placeholder.py | 26 +-- backends/arm/operators/op_sub.py | 6 +- backends/arm/test/misc/test_debug_feats.py | 31 ++-- .../arm/test/models/test_mobilenet_v2_arm.py | 16 +- backends/arm/test/ops/test_linear.py | 47 ++++-- backends/arm/test/tester/arm_tester.py | 141 +++++++++++------ backends/arm/tosa_utils.py | 61 +------- 14 files changed, 275 insertions(+), 343 deletions(-) create mode 100644 backends/arm/_passes/decompose_linear_pass.py delete mode 100644 backends/arm/operators/op_addmm.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index a6c9cf1d06..a72cdfd1a0 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -23,6 +23,7 @@ from executorch.backends.arm._passes.decompose_layernorm_pass import ( DecomposeLayerNormPass, ) +from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_softmaxes_pass import ( DecomposeSoftmaxesPass, @@ -74,6 +75,7 @@ def transform_to_backend_pipeline( self.add_pass(ConvertSplitToSlicePass()) self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) + self.add_pass(DecomposeLinearPass()) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py new file mode 100644 index 0000000000..30767b354d --- /dev/null +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -0,0 +1,112 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class DecomposeLinearPass(ExportPass): + """ + This pass decomposes linear into a Conv2D with the required view operations. + linear(x, weights, bias) becomes: + x_reshaped = view(x) + weights_reshaped = view(weights) + conv2d = conv2d(x_reshaped, weights_reshaped, bias) + output = view(conv2d) + It also inserts q/dq pairs if the linear node was quantized. + """ + + def call(self, graph_module): + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target != exir_ops.edge.aten.linear.default: + continue + args = node.args + input = args[0] + weights = args[1] + bias = args[2] if len(args) > 2 else None + output_shape = get_first_fake_tensor(node).shape + input_shape = get_first_fake_tensor(input).shape + weights_shape = get_first_fake_tensor(weights).shape + batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1 + # input has shape (..., Ci) + input_reshaped_shape = [batches, input_shape[-1], 1, 1] + # weights have shape (Co, Ci) + weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1] + + with graph_module.graph.inserting_before(node): + quantize = input.op == "call_function" and input.target == dq_op + q_params = input.args[1:] if quantize else None + # Reshape input to 4D with shape (N, Ci, 1, 1) + input_reshaped = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(input, input_reshaped_shape), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + quantize = weights.op == "call_function" and weights.target == dq_op + q_params = weights.args[1:] if quantize else None + # Reshape weights to 4D with shape (Co, Ci, 1, 1) + weights_reshaped = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(weights, weights_reshaped_shape), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + consumer_node = list(node.users)[0] + quantize = ( + consumer_node.op == "call_function" and consumer_node.target == q_op + ) + q_params = consumer_node.args[1:] if quantize else None + conv = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.convolution.default, + args=( + input_reshaped, + weights_reshaped, + bias, + [1, 1], # strides + [0, 0], # padding + [1, 1], # dilation + False, # transposed + [0, 0], # output padding + 1, # groups + ), + kwargs={}, + quantize=quantize, + q_params=q_params, + ) + + with graph_module.graph.inserting_after(conv): + # Reshape output to same rank as original input with shape (..., Co) + # No need to insert q/dq pair as Conv2D node above has inserted them if + # required. + output = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(conv, list(output_shape)), + kwargs={}, + ) + + node.replace_all_uses_with(output) + graph_module.graph.erase_node(node) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index bdd4b80f29..ef924fa434 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -8,7 +8,7 @@ import logging import operator import os -from typing import cast, final, List +from typing import Callable, cast, final, List, Optional, Tuple import torch from executorch.backends.arm.arm_backend import ArmBackend # usort: skip @@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: supported = node.op == "call_function" and node.target in [ exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.bmm.default, @@ -49,6 +48,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.log.default, + exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.mul.Tensor, @@ -137,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags ) + + def ops_to_not_decompose( + self, + ep: ExportedProgram, + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + ops_to_not_decompose = [ + torch.ops.aten.linear.default, + ] + return (ops_to_not_decompose, None) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 5e188aea77..6e51c2c141 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -8,7 +8,6 @@ from . import ( # noqa node_visitor, op_add, - op_addmm, op_avg_pool2d, op_batch_norm, op_bmm, diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index d0518c7a5e..7a71a0d2bd 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -62,10 +62,8 @@ def define_node( input_nodes, tosa_graph ) - # Preapre sub output tensor - broadcasted_shape = tutils.broadcast_shapes( - rescaled_inputs[0].shape, rescaled_inputs[0].shape - ) + # Prepare add output tensor + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) else: add_output = output diff --git a/backends/arm/operators/op_addmm.py b/backends/arm/operators/op_addmm.py deleted file mode 100644 index b4f782db4a..0000000000 --- a/backends/arm/operators/op_addmm.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args - -from executorch.backends.arm.tosa_utils import build_reshape -from executorch.exir.dialects._ops import ops as exir_ops -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class AddmmVisitor(NodeVisitor): - target = "aten.addmm.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - bias, input, weight = inputs - - N = input.shape[0] - input_channels = input.shape[1] - output_channels = weight.shape[1] - - input_new_shape = (N, 1, 1, input_channels) - input_reshaped = tosa_graph.addIntermediate( - input_new_shape, - ts.DType.INT8 if is_quant_node else input.dtype, - ) - - build_reshape(tosa_graph, input.name, input_new_shape, input_reshaped.name) - - weight_new_shape = (output_channels, 1, 1, input_channels) - weight_reshaped = tosa_graph.addIntermediate( - weight_new_shape, - ts.DType.INT8 if is_quant_node else weight.dtype, - ) - - build_reshape(tosa_graph, weight.name, weight_new_shape, weight_reshaped.name) - - # Get the attributes of convolution. - attr = ts.TosaSerializerAttribute() - pad_attr = [0, 0, 0, 0] - stride_attr = [1, 1] - dilation_attr = [1, 1] - - input_zp = 0 - if is_quant_node: - input_node = node.all_input_nodes[1] - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - else: - quant_node = input_node - input_zp = get_quant_node_args(quant_node).zp - attr.ConvAttribute( - pad=pad_attr, - stride=stride_attr, - dilation=dilation_attr, - input_zp=input_zp, - weight_zp=0, - local_bound=False, - ) - - conv2d_output_shape = (N, 1, 1, output_channels) - conv2d_res = tosa_graph.addIntermediate( - conv2d_output_shape, - ts.DType.INT32 if is_quant_node else output.dtype, - ) - - # U55 doesn't support tosa.matmul and tosa.fully_connected will be deprecated - # TOSA Conv2d input is NHWC and weights are in OHWI - tosa_graph.addOperator( - TosaOp.Op().CONV2D, - [ - input_reshaped.name, - weight_reshaped.name, - bias.name, - ], - [conv2d_res.name], - attr, - ) - - result_shape = (N, output_channels) - - if is_quant_node: - # Read inputs' parent nodes - _, input_node, weight_node = node.all_input_nodes - - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - input_scale = get_quant_node_args(quant_node).scale - consumer_node = list(node.users)[0] - consumer_consumer_node = list(consumer_node.users)[0] - quant_args = get_quant_node_args(consumer_consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp - else: - input_scale = get_quant_node_args(input_node).scale - consumer_node = list(node.users)[0] - quant_args = get_quant_node_args(consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp - - weight_node_q_node = weight_node.all_input_nodes[0] - weight_scale = get_quant_node_args(weight_node_q_node).scale - - output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale - - reshaped_res = tosa_graph.addIntermediate(result_shape, ts.DType.INT32) - build_reshape(tosa_graph, conv2d_res.name, result_shape, reshaped_res.name) - - build_rescale( - tosa_fb=tosa_graph, - scale=output_rescale_scale, - input_node=reshaped_res, - output_name=output.name, - output_type=ts.DType.INT8, - output_shape=reshaped_res.shape, - input_zp=0, - output_zp=consumer_node_node_zp, - is_double_round=False, - ) - - else: - # non-quantized case - build_reshape(tosa_graph, conv2d_res.name, result_shape, output.name) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 69f6f6506c..8142d6d654 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -14,7 +14,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import is_permute_node_before_addmm from serializer.tosa_serializer import TosaOp @@ -81,13 +80,6 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_permute_node_before_addmm(node): - ## Simply add an identityOp - tosa_graph.addOperator( - TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] - ) - return - # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index f3e52e68f7..950d4636d2 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -16,11 +16,9 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import ( - is_bias_node_for_quantized_addmm, is_bias_node_for_quantized_conv, tosa_shape, ) -from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram @@ -57,25 +55,13 @@ def process_quantized_bias( ): """ Serialize bias node that needs to be quantized. - This can be either an addmm or conv bias node. """ consumer_node = list(node.users)[0] - if is_bias_node_for_quantized_addmm(node): - ( - _, - input_node, - weight_node_permuted, - ) = consumer_node.all_input_nodes - - weight_node = weight_node_permuted.all_input_nodes[0] - if input_node.target == exir_ops.edge.aten.view_copy.default: - input_node = input_node.all_input_nodes[0] - else: - ( - input_node, - weight_node, - _, - ) = consumer_node.all_input_nodes + ( + input_node, + weight_node, + _, + ) = consumer_node.all_input_nodes input_node_scale = get_quant_node_args(input_node).scale weight_node_scale = get_quant_node_args(weight_node).scale @@ -107,7 +93,7 @@ def process_inputs_to_parameters( assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor" parameter_values = parameter_data.detach().numpy() - if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node): + if is_bias_node_for_quantized_conv(node): # BI bias assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer" process_quantized_bias(node, tosa_graph, parameter_values) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 2089b6e9e9..b86a5ea3ad 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -43,10 +43,8 @@ def define_node( input_nodes, tosa_graph ) - # Preapre sub output tensor - broadcasted_shape = tutils.broadcast_shapes( - rescaled_inputs[0].shape, rescaled_inputs[0].shape - ) + # Prepare sub output tensor + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) # Do the INT32 Sub diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 66e3e52d4d..4cac39af70 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -67,8 +67,7 @@ def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None): ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_artifact(dump_file) .dump_artifact() ) @@ -108,12 +107,11 @@ def test_numerical_diff_prints(self): model, example_inputs=model.get_inputs(), compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=False + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True ), ) .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .to_executorch() ) # We expect an assertion error here. Any other issues will cause the @@ -142,10 +140,7 @@ def test_dump_ops_and_dtypes(): .export() .dump_dtype_distribution() .dump_operator_distribution() - .to_edge() - .dump_dtype_distribution() - .dump_operator_distribution() - .partition() + .to_edge_transform_and_lower() .dump_dtype_distribution() .dump_operator_distribution() ) @@ -166,10 +161,7 @@ def test_dump_ops_and_dtypes_parseable(): .export() .dump_dtype_distribution(print_table=False) .dump_operator_distribution(print_table=False) - .to_edge() - .dump_dtype_distribution(print_table=False) - .dump_operator_distribution(print_table=False) - .partition() + .to_edge_transform_and_lower() .dump_dtype_distribution(print_table=False) .dump_operator_distribution(print_table=False) ) @@ -193,8 +185,7 @@ def test_collate_tosa_BI_tests(self): ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .to_executorch() ) # test that the output directory is created and contains the expected files @@ -202,10 +193,10 @@ def test_collate_tosa_BI_tests(self): "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests" ) assert os.path.exists( - "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa" + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag5.tosa" ) assert os.path.exists( - "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json" + "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag5.json" ) os.environ.pop("TOSA_TESTCASES_BASE_PATH") @@ -223,8 +214,7 @@ def test_dump_tosa_ops(caplog): ) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_operator_distribution() ) assert "TOSA operators:" in caplog.text @@ -243,8 +233,7 @@ def forward(self, x): ArmTester(model, example_inputs=(torch.ones(5),), compile_spec=compile_spec) .quantize() .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_operator_distribution() ) assert "Can not get operator distribution for Vela command stream." in caplog.text diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index 97a802b15d..19b4254575 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -59,9 +59,7 @@ def test_mv2_tosa_MI(self): ), ) .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.all_operators)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .run_method_and_compare_outputs(inputs=self.model_inputs) ) @@ -77,9 +75,7 @@ def test_mv2_tosa_BI(self): ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() # atol=1.0 is a defensive upper limit # TODO MLETROCH-72 @@ -96,9 +92,7 @@ def test_mv2_u55_BI(self): ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .serialize() ) @@ -116,9 +110,7 @@ def test_mv2_u85_BI(self): ) .quantize() .export() - .to_edge(config=self._edge_compile_config) - .check(list(self.operators_after_quantization)) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .to_executorch() .serialize() ) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 6221af8446..c7a475035d 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -23,70 +23,82 @@ test_data_suite_rank1 = [ - # (test_name, test_data, out_features) + # (test_name, test_data, out_features, has_bias) ( "model_linear_rank1_zeros", torch.zeros(10), 15, + True, ), ( "model_linear_rank1_ones", torch.ones(10), 15, + False, ), ( "model_linear_rank1_negative_ones", torch.ones(10) * (-1), 20, + True, ), ( "model_linear_rank1_rand", torch.rand(10), 10, + True, ), ( "model_linear_rank1_negative_large_rand", torch.rand(10) * (-100), 30, + False, ), ( "model_linear_rank1_large_randn", torch.randn(15) * 100, 20, + True, ), ] test_data_suite_rank4 = [ - # (test_name, test_data, out_features) + # (test_name, test_data, out_features, has_bias) ( "model_linear_rank4_zeros", torch.zeros(5, 10, 25, 20), 30, + True, ), ( "model_linear_rank4_ones", torch.ones(5, 10, 25, 20), 30, + False, ), ( "model_linear_rank4_negative_ones", torch.ones(5, 10, 25, 20) * (-1), 30, + True, ), ( "model_linear_rank4_rand", torch.rand(5, 10, 25, 20), 30, + False, ), ( "model_linear_rank4_negative_large_rand", torch.rand(5, 10, 25, 20) * (-100), 30, + True, ), ( "model_linear_rank4_large_randn", torch.randn(5, 10, 25, 20) * 100, 30, + False, ), ] @@ -123,14 +135,13 @@ def _test_linear_tosa_MI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+MI", permute_memory_to_nhwc=False + "TOSA-0.80.0+MI", permute_memory_to_nhwc=True ), ) .export() .check_count({"torch.ops.aten.linear.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data) @@ -144,15 +155,14 @@ def _test_linear_tosa_BI_pipeline( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80.0+BI", permute_memory_to_nhwc=False + "TOSA-0.80.0+BI", permute_memory_to_nhwc=True ), ) .quantize() .export() .check_count({"torch.ops.aten.linear.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .run_method_and_compare_outputs(inputs=test_data, qtol=True) @@ -174,8 +184,7 @@ def _test_linear_tosa_ethosu_BI_pipeline( .export() .check_count({"torch.ops.aten.linear.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() + .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() @@ -188,6 +197,7 @@ def test_linear_tosa_MI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -195,6 +205,7 @@ def test_linear_tosa_MI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), test_data, ) @@ -205,11 +216,15 @@ def test_linear_tosa_BI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) self._test_linear_tosa_BI_pipeline( - self.Linear(in_features=in_features, out_features=out_features), test_data + self.Linear( + in_features=in_features, out_features=out_features, bias=has_bias + ), + test_data, ) @parameterized.expand(test_data_suite_rank1) @@ -218,6 +233,7 @@ def test_linear_tosa_u55_BI( test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -225,20 +241,22 @@ def test_linear_tosa_u55_BI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), - common.get_u55_compile_spec(permute_memory_to_nhwc=False), + common.get_u55_compile_spec(), test_data, ) if common.is_option_enabled("corstone300"): tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - @parameterized.expand(test_data_suite_rank1) + @parameterized.expand(test_data_suite_rank1 + test_data_suite_rank4) def test_linear_tosa_u85_BI( self, test_name: str, test_data: torch.Tensor, out_features: int, + has_bias: bool, ): in_features = test_data.shape[-1] test_data = (test_data,) @@ -246,7 +264,8 @@ def test_linear_tosa_u85_BI( self.Linear( in_features=in_features, out_features=out_features, + bias=has_bias, ), - common.get_u85_compile_spec(permute_memory_to_nhwc=False), + common.get_u85_compile_spec(), test_data, ) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 096bc2b22f..14a9d1df41 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -39,10 +39,11 @@ from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.backend_debug import get_delegation_info -from executorch.exir import EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager from executorch.exir.backend.compile_spec_schema import CompileSpec - +from executorch.exir.backend.partitioner import Partitioner from executorch.exir.lowered_backend_module import LoweredBackendModule + from tabulate import tabulate from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec from torch.fx import Graph @@ -50,50 +51,61 @@ logger = logging.getLogger(__name__) +def _dump_lowered_modules_artifact( + path_to_dump: Optional[str], + artifact: ExecutorchProgramManager, + graph_module: torch.fx.GraphModule, +): + output = "Formated Graph Signature:\n" + output += _format_export_graph_signature( + artifact.exported_program().graph_signature + ) + + def get_output_format(lowered_module) -> str | None: + for spec in lowered_module.compile_specs: + if spec.key == "output_format": + return spec.value.decode() + return None + + for node in graph_module.graph.nodes: + if node.op == "get_attr" and node.name.startswith("lowered_module_"): + lowered_module = getattr(graph_module, node.name) + assert isinstance( + lowered_module, LoweredBackendModule + ), f"Attribute {node.name} must be of type LoweredBackendModule." + + output_format = get_output_format(lowered_module) + if output_format == "tosa": + tosa_fb = lowered_module.processed_bytes + to_print = dbg_tosa_fb_to_json(tosa_fb) + to_print = pformat(to_print, compact=True, indent=1) + output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" + elif output_format == "vela": + vela_cmd_stream = lowered_module.processed_bytes + output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" + else: + logger.warning( + f"No TOSA nor Vela compile spec found in compile specs of {node.name}." + ) + continue + + if not output: + logger.warning("No output to print generated from artifact.") + return + + _dump_str(output, path_to_dump) + + class Partition(tester.Partition): def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) + _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) - output = "Formated Graph Signature:\n" - output += _format_export_graph_signature( - self.artifact.exported_program().graph_signature - ) - - def get_output_format(lowered_module) -> str | None: - for spec in lowered_module.compile_specs: - if spec.key == "output_format": - return spec.value.decode() - return None - - for node in self.graph_module.graph.nodes: - if node.op == "get_attr" and node.name.startswith("lowered_module_"): - lowered_module = getattr(self.graph_module, node.name) - assert isinstance( - lowered_module, LoweredBackendModule - ), f"Attribute {node.name} must be of type LoweredBackendModule." - - output_format = get_output_format(lowered_module) - if output_format == "tosa": - tosa_fb = lowered_module.processed_bytes - to_print = dbg_tosa_fb_to_json(tosa_fb) - to_print = pformat(to_print, compact=True, indent=1) - output += f"\nTOSA deserialized {node.name}: \n{to_print}\n" - elif output_format == "vela": - vela_cmd_stream = lowered_module.processed_bytes - output += ( - f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n" - ) - else: - logger.warning( - f"No TOSA nor Vela compile spec found in compile specs of {node.name}." - ) - continue - if not output: - logger.warning("No output to print generated from artifact.") - return - - _dump_str(output, path_to_dump) +class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower): + def dump_artifact(self, path_to_dump: Optional[str]): + super().dump_artifact(path_to_dump) + _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module) class Serialize(tester.Serialize): @@ -211,6 +223,26 @@ def partition(self, partition_stage: Optional[Partition] = None): partition_stage = Partition(arm_partitioner) return super().partition(partition_stage) + def to_edge_transform_and_lower( + self, + to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + ): + if to_edge_and_lower_stage is None: + if partitioners is None: + partitioners = [ArmPartitioner(compile_spec=self.compile_spec)] + to_edge_and_lower_stage = ToEdgeTransformAndLower( + partitioners, edge_compile_config + ) + else: + if partitioners is not None: + to_edge_and_lower_stage.partitioners = partitioners + if edge_compile_config is not None: + to_edge_and_lower_stage.edge_compile_conf = edge_compile_config + to_edge_and_lower_stage.edge_compile_conf._skip_dim_order = True + return super().to_edge_transform_and_lower(to_edge_and_lower_stage) + def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None): if to_executorch_stage is None: to_executorch_stage = ToExecutorch(self.runner_util) @@ -255,21 +287,23 @@ def run_method_and_compare_outputs( inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data. The default is random data. """ + + edge_stage = self.stages[self.stage_name(tester.ToEdge)] + if edge_stage is None: + edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)] assert ( self.runner_util is not None ), "self.tosa_test_util is not initialized, cannot use run_method()" assert ( - self.stages[self.stage_name(tester.ToEdge)] is not None - ), "To compare outputs, at least the ToEdge stage needs to be run." + edge_stage is not None + ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run." stage = stage or self.cur test_stage = self.stages[stage] is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None exported_program = self.stages[self.stage_name(tester.Export)].artifact - edge_program = self.stages[ - self.stage_name(tester.ToEdge) - ].artifact.exported_program() + edge_program = edge_stage.artifact.exported_program() self.runner_util.init_run( exported_program, edge_program, @@ -333,8 +367,10 @@ def get_graph(self, stage: str | None = None) -> Graph: if stage is None: stage = self.cur artifact = self.get_artifact(stage) - if self.cur == self.stage_name(tester.ToEdge) or self.cur == self.stage_name( - Partition + if ( + self.cur == self.stage_name(tester.ToEdge) + or self.cur == self.stage_name(Partition) + or self.cur == self.stage_name(ToEdgeTransformAndLower) ): graph = artifact.exported_program().graph elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name( @@ -362,7 +398,14 @@ def dump_operator_distribution( line = "#" * 10 to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n" - if self.cur == self.stage_name(tester.Partition) and print_table: + if ( + self.cur + in ( + self.stage_name(tester.Partition), + self.stage_name(ToEdgeTransformAndLower), + ) + and print_table + ): graph_module = self.get_artifact().exported_program().graph_module if print_table: delegation_info = get_delegation_info(graph_module) diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index bf60aaf0f8..c91d89b1b9 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -131,31 +131,6 @@ def get_output_node(node: Node) -> Node: return list(node.users)[0] -# Helper function to do broadcasting -# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_broadcasting -def broadcast_shapes(shape1, shape2): - assert len(shape1) == len(shape2), "broadcast_shapes::shapes must have same ranks" - - need_broadcasting = False - for val1, val2 in zip(shape1, shape2): - if val1 != val2: - need_broadcasting = True - if not need_broadcasting: - return shape1 - - broadcasted_shape = list(shape1) - shape2 = list(shape2) - for idx, _ in enumerate(broadcasted_shape): - if broadcasted_shape[idx] == 1: - broadcasted_shape[idx] = shape2[idx] - else: - assert not ( - shape2[idx] != 1 and shape2[idx] != broadcasted_shape[idx] - ), "broadcast_shapes::broadcast shape mismatch" - - return broadcasted_shape - - """ TOSA reshape returns a tensor with the same type/values as the input. No data conversion happens during a reshape operation. """ @@ -166,36 +141,6 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name): tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr) -def is_permute_node_before_addmm(node): - return ( - node.target == exir_ops.edge.aten.permute_copy.default - and list(node.users)[0].target == exir_ops.edge.aten.addmm.default - ) - - -def is_bias_node_for_quantized_addmm(node): - consumer_node = list(node.users)[0] - # consumer node is addmm - is_rank2_linear_bias = ( - consumer_node.target == exir_ops.edge.aten.addmm.default - and list(consumer_node.users)[0].target == q_op - ) - - # rank>2 linear layers - # consumer_consumer node is view_copy - is_rank_greater_than_2_linear_bias = False - if ( - consumer_node.target == exir_ops.edge.aten.addmm.default - and list(consumer_node.users)[0].target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - is_rank_greater_than_2_linear_bias = ( - list(consumer_consumer_node.users)[0].target == q_op - ) - - return is_rank2_linear_bias or is_rank_greater_than_2_linear_bias - - def is_bias_node_for_quantized_conv(node): consumer_node = list(node.users)[0] return ( @@ -301,11 +246,7 @@ def process_call_function( tosa_graph.currRegion.currBasicBlock.addTensor( output.name, - ( - tosa_shape(inputs[0].shape, inputs[0].dim_order) - if is_permute_node_before_addmm(node) - else tosa_shape(output.shape, output.dim_order) - ), + (tosa_shape(output.shape, output.dim_order)), map_dtype(get_quant_node_dtype(node)) if is_quant_node(node) else output.dtype, ) From 793f17e52f034ee8823a858e5f0eac2c315942fa Mon Sep 17 00:00:00 2001 From: Gasoonjia Date: Mon, 11 Nov 2024 02:37:02 -0800 Subject: [PATCH 42/59] introduce slack channel for community Differential Revision: D65609750 Pull Request resolved: https://github.com/pytorch/executorch/pull/6717 --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index da2cb82ef9..aded66bf40 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,11 @@ We recommend using the latest release tag from the See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code style, CI jobs, and other development topics. +To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can: +* Head to the `#executorch-general` channel for general questions, discussion, and community support. +* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development. + + ## Directory Structure ``` From 7fcd0af54663b3aabfe92e6a5bee42a27b59478b Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 11 Nov 2024 14:43:48 +0100 Subject: [PATCH 43/59] Search graph for quantization parameters (#6690) * Search graph for quantization nodes Generalizes the search for quantization parameters. The idea is to make a graph like this a valid quantized graph: dq -> view -> transpose -> some_op ^ / dq ------> expand -------/ For a subset of operations 'passable_op' it is is allowed to "pass through" the op when searching for qparams. If multiple qparams are encounterd in one search, they are asserted to be equal. Signed-off-by: Erik Lundell --- .../annotate_channels_last_dim_order_pass.py | 5 +- .../_passes/insert_squeeze_after_sum_pass.py | 14 +- .../arm/_passes/size_adjust_conv2d_pass.py | 4 +- backends/arm/operators/op_bmm.py | 16 +- backends/arm/operators/op_conv2d.py | 20 +- backends/arm/operators/op_exp.py | 7 +- backends/arm/operators/op_full.py | 11 +- backends/arm/operators/op_hardtanh.py | 13 +- backends/arm/operators/op_log.py | 7 +- backends/arm/operators/op_max_pool2d.py | 9 +- backends/arm/operators/op_mm.py | 16 +- backends/arm/operators/op_mul.py | 4 +- backends/arm/operators/op_placeholder.py | 17 +- backends/arm/operators/op_reciprocal.py | 7 +- backends/arm/operators/op_relu.py | 2 +- backends/arm/operators/op_rsqrt.py | 7 +- backends/arm/operators/op_sigmoid.py | 7 +- backends/arm/operators/op_tanh.py | 7 +- backends/arm/quantizer/arm_quantizer.py | 2 +- .../generic_annotator.py | 3 + .../quantization_annotation/mm_annotator.py | 4 +- backends/arm/test/ops/test_bmm.py | 20 +- backends/arm/test/ops/test_linear.py | 2 +- backends/arm/tosa_quant_utils.py | 236 +++++++++++++----- backends/arm/tosa_utils.py | 20 +- backends/arm/util/arm_model_evaluator.py | 15 +- 26 files changed, 317 insertions(+), 158 deletions(-) diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 77def9e7cd..786117e645 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -14,7 +14,7 @@ get_first_fake_tensor, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs): return args[0] +register_passable_op(torch.ops.passthrough_to_tosa._transpose) + + class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/insert_squeeze_after_sum_pass.py index 152d5c95f6..adf2b4f491 100644 --- a/backends/arm/_passes/insert_squeeze_after_sum_pass.py +++ b/backends/arm/_passes/insert_squeeze_after_sum_pass.py @@ -8,9 +8,7 @@ import torch import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair - -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node +from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass): sum(dims, keep_dim = False) After pass: sum(dims, keep_dim = True) - (q) - (dq) squeeze(dim = dims) """ @@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule): continue dim_list = cast(list[int], sum_node.args[1]) - quantized = is_quant_node(sum_node) - if quantized: - qparams = get_quant_node_args(sum_node.all_input_nodes[0]) - qparams = qparams + (torch.int8,) - else: - qparams = None # Add keep_dim = True arg to sum node. sum_node.args = sum_node.args[0:2] + (True,) @@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) sum_node.replace_all_uses_with(squeeze_node) squeeze_node.args = (sum_node, dim_list) - if quantized: - sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams) graph_module.graph.eliminate_dead_code() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/_passes/size_adjust_conv2d_pass.py b/backends/arm/_passes/size_adjust_conv2d_pass.py index 980ab09e59..c7bd27dcce 100644 --- a/backends/arm/_passes/size_adjust_conv2d_pass.py +++ b/backends/arm/_passes/size_adjust_conv2d_pass.py @@ -9,7 +9,7 @@ from typing import cast, Optional import torch.fx -from executorch.backends.arm.tosa_quant_utils import is_quant_node +from executorch.backends.arm.tosa_quant_utils import is_node_quantized from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch._ops import OpOverload @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule): slice_node = graph.create_node( "call_function", self.slice_op, (last_node,) + args ) - if is_quant_node(last_node): + if is_node_quantized(last_node): q_params = last_node.args[1:] dq_node = insert_q_dq_pair( graph_module.graph, slice_node, q_params diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 161b5d2239..8c9bd7ac2a 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + get_quant_arg_downstream, + get_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import get_two_inputs from serializer.tosa_serializer import TosaOp @@ -42,8 +46,10 @@ def define_node( # For INT8, we need to get the zero points and add an intermediate tensor # for a later rescale. if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_q_params = get_quant_arg_upstream(input0) + input1_q_params = get_quant_arg_upstream(input1) + input0_zp = input0_q_params.zp + input1_zp = input1_q_params.zp bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: @@ -63,9 +69,7 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + output_q_params = get_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 64cde0724f..ffbeee7306 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import cast, List +from typing import List import serializer.tosa_serializer as ts import torch @@ -15,9 +15,10 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( build_rescale_conv_output, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, ) -from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape +from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -82,7 +83,7 @@ def define_node( ) input_zp = ( - get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0 + get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0 ) attr.ConvAttribute( @@ -158,9 +159,10 @@ def define_node( # integer value domain of the next op. Otherwise return float32 output. if is_quant_node: # Get scale_factor from input, weight, and output. - _, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0])) - _, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1])) - _, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0]) + input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale + weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale + output_qargs = get_quant_arg_downstream(list(node.users)[0]) + build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. @@ -169,6 +171,6 @@ def define_node( actual_out_type, input_scale, weight_scale, - output_scale, - output_zp, + output_qargs.scale, + output_qargs.zp, ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 0e0a75dcc4..7a0b4e104f 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -48,9 +49,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = exp_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index cf67975e0d..d2bc1377ce 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,7 +14,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_downstream, + quantize_value, +) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -39,10 +42,8 @@ def define_node( value = inputs[1].number if is_quant_node: - qargs = get_quant_node_args(list(node.users)[0]) - qvalue = np.clip( - np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax - ) + qargs = get_quant_arg_downstream(list(node.users)[0]) + qvalue = quantize_value(value, qargs) dtype = ts.DType.INT8 data = np.full(shape, qvalue, dtype=np.int8) else: diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 62c0a27f05..e726028206 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -14,7 +14,10 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_upstream, + quantize_value, +) from serializer.tosa_serializer import TosaOp @@ -37,12 +40,10 @@ def define_node( if is_quant_node: # Get quant parameters - scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0]) + qargs = get_quant_arg_upstream(node.all_input_nodes[0]) # Convert to quantized representation - clamp_min_qs = round((inputs[1].number / scale) + zp) - clamp_min_qs = max(clamp_min_qs, qmin) - clamp_max_qs = round((inputs[2].number / scale) + zp) - clamp_max_qs = min(clamp_max_qs, qmax) + clamp_min_qs = quantize_value(inputs[1].number, qargs) + clamp_max_qs = quantize_value(inputs[2].number, qargs) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 5276173efa..76adc2325e 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = log_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index a0b868f684..74e33ddb02 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -13,7 +13,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import get_quant_node_args +from executorch.backends.arm.tosa_utils import ( + get_quant_arg_downstream, + get_quant_arg_upstream, +) from serializer.tosa_serializer import TosaOp @@ -54,8 +57,8 @@ def define_node( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args(node.all_input_nodes[0]).zp - output_zp = get_quant_node_args(list(node.users)[0]).zp + input_zp = get_quant_arg_upstream(node.all_input_nodes[0]).zp + output_zp = get_quant_arg_downstream(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index ebddb3a40e..81334de16c 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + get_quant_arg_downstream, + get_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import ( build_reshape, expand_dims, @@ -54,8 +58,8 @@ def define_node( # For INT8, we need to get the zero point, otherwise it is 0 input0_zp, input1_zp = 0, 0 if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_zp = get_quant_arg_upstream(input0).zp + input1_zp = get_quant_arg_upstream(input1).zp mat_mul_result = tosa_graph.addIntermediate( output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype @@ -86,9 +90,9 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + input0_q_params = get_quant_arg_upstream(input0) + input1_q_params = get_quant_arg_upstream(input1) + output_q_params = get_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index c152e8759e..ad578aa1f0 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -37,10 +37,10 @@ def define_node( if is_quant_node: input_A = inputs[0] input_B = inputs[1] - input_A_qargs = tqutils.get_quant_node_args( + input_A_qargs = tqutils.get_quant_arg_upstream( cast(torch.fx.Node, node.args[0]) ) - input_B_qargs = tqutils.get_quant_node_args( + input_B_qargs = tqutils.get_quant_arg_upstream( cast(torch.fx.Node, node.args[1]) ) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 950d4636d2..d466a13e38 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -10,13 +10,14 @@ import torch.fx from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_dtype, - get_quant_node_args, - is_quant_arg, + get_quant_arg_upstream, + get_quantized_node_output_dtype, + is_node_quantized, ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import ( is_bias_node_for_quantized_conv, + map_dtype, tosa_shape, ) from torch.export.exported_program import ExportedProgram @@ -41,7 +42,11 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype, + ( + map_dtype(get_quantized_node_output_dtype(node)) + if is_node_quantized(node) + else inputs[0].dtype + ), data=None, placeholderFilename=inputs[0].name + ".npy", ) @@ -63,8 +68,8 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_node_scale = get_quant_node_args(input_node).scale - weight_node_scale = get_quant_node_args(weight_node).scale + input_node_scale = get_quant_arg_upstream(input_node).scale + weight_node_scale = get_quant_arg_upstream(weight_node).scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 3d43fd8f7d..774c4d94b1 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -15,7 +15,8 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -41,8 +42,8 @@ def define_node( if is_quant_node: input = inputs[0] - input_qargs = get_quant_node_args(node.all_input_nodes[0]) - output_qargs = get_quant_node_args(list(node.users)[0]) + input_qargs = get_quant_arg_upstream(node.all_input_nodes[0]) + output_qargs = get_quant_arg_downstream(list(node.users)[0]) div_table = div_table_8bit(input_qargs, output_qargs) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index 20bba3f654..a3a7c82ab8 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -38,7 +38,7 @@ def define_node( clamp_min_qs = 0 clamp_max_qs = 0 if is_quant_node: - out_qargs = tqutils.get_quant_node_args(list(node.users)[0]) + out_qargs = tqutils.get_quant_arg_downstream(list(node.users)[0]) clamp_min_qs = tqutils.quantize_value(0, out_qargs) clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 9225c7d938..b503a323b1 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -16,7 +16,8 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -39,9 +40,9 @@ def define_node( # Assume quantized input is 8 bit. # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = rsqrt_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() table_attr.TableAttribute(table) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 0087b1f7a8..e299e99b43 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = sigmoid_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 20f343a7f1..2c84580edc 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -17,7 +17,8 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, + get_quant_arg_downstream, + get_quant_arg_upstream, QuantArgs, quantize_value, ) @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = get_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = get_quant_arg_downstream(output_node) table = tanh_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index e61fbc5bbe..511aeda1ac 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -75,7 +75,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern [torch.nn.AdaptiveAvgPool2d], [F.adaptive_avg_pool2d], ], - "mul": [torch.mul], + "mul": [[torch.mul]], "sub": [[torch.sub]], } return copy.deepcopy(supported_operators) diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py index 126051f158..b093eec808 100644 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py @@ -29,6 +29,9 @@ torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default, torch.ops.aten.reshape.default, + torch.ops.aten.repeat.default, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand.default, # Disabling these as there seems to be an issue with support for complex # datatypes in torch: # torch.ops.aten.view_as_complex.default, diff --git a/backends/arm/quantizer/quantization_annotation/mm_annotator.py b/backends/arm/quantizer/quantization_annotation/mm_annotator.py index b48c6d5990..60d9adb1c3 100644 --- a/backends/arm/quantizer/quantization_annotation/mm_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/mm_annotator.py @@ -24,7 +24,9 @@ def _annotate_mm( quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: - mm_partitions = get_source_partitions(gm.graph, [torch.mm, torch.bmm], filter_fn) + mm_partitions = get_source_partitions( + gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn + ) mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values())) annotated_partitions = [] for mm_partition in mm_partitions: diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index e5e9508e25..6246657120 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -32,6 +32,12 @@ class BMM(torch.nn.Module): def forward(self, x, y): return torch.bmm(x, y) + class MatMul(torch.nn.Module): + test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))] + + def forward(self, x, y): + return torch.matmul(x, y) + class BMMSingleInput(torch.nn.Module): test_parameters = [ (torch.rand(20, 3, 3),), @@ -53,9 +59,9 @@ def _test_bmm_tosa_MI_pipeline( compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), ) .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -74,9 +80,9 @@ def _test_bmm_tosa_BI_pipeline( ) .quantize() .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -116,6 +122,16 @@ def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor): test_data = (operand1,) self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data) + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) + + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) + @parameterized.expand(BMM.test_parameters) def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index c7a475035d..30d4b2890a 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -165,7 +165,7 @@ def _test_linear_tosa_BI_pipeline( .to_edge_transform_and_lower(edge_compile_config=self._edge_compile_config) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=True) + .run_method_and_compare_outputs(inputs=test_data, qtol=1) ) def _test_linear_tosa_ethosu_BI_pipeline( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index fe408e41b3..19397fe6b2 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -8,21 +8,38 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import NamedTuple, Sequence +from typing import Callable, cast, NamedTuple, Sequence import numpy as np import serializer.tosa_serializer as ts import torch.fx import tosa.Op as TosaOp -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaSerializerTensor from torch.fx import Node + q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default -dq_q_ops = [q_op, dq_op] +dq_q_ops = (q_op, dq_op) +passable_ops = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, +] + + +def register_passable_op(op): + """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" + passable_ops.append(op) class QuantArgs(NamedTuple): @@ -30,6 +47,19 @@ class QuantArgs(NamedTuple): zp: int qmin: int qmax: int + dtype: torch.dtype + + def quantize_value(self, x): + if not isinstance(x, torch.Tensor): + x = torch.Tensor([x]) + return torch.clip( + torch.round(x / self.scale) + self.zp, + self.qmin, + self.qmax, + ).to(self.dtype) + + def dequantize_value(self, qx: int) -> float: + return (qx - self.zp) * self.scale def quantize_value(x, qargs: QuantArgs, dtype=np.int8): @@ -44,81 +74,159 @@ def dequantize_value(qx, qargs: QuantArgs): return (qx - qargs.zp) * qargs.scale -def is_quant_node(node: torch.fx.Node): +def qargs_from_qnode(node: torch.fx.Node): + assert node.target in dq_q_ops, f"Op {node} is not a quant node." - consumer_node_condition = False - if len(list(node.users)) > 0: - consumer_node = list(node.users)[0] + return QuantArgs( + scale=cast(float, node.args[1]), + zp=cast(int, node.args[2]), + qmin=cast(int, node.args[3]), + qmax=cast(int, node.args[4]), + dtype=cast(torch.dtype, node.args[5]), + ) - # For Rank > 2 Linear layers, the quant node is after the view_copy - if ( - node.target == exir_ops.edge.aten.addmm.default - and consumer_node.target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - return True if consumer_consumer_node.target == q_op else False - consumer_node_condition = consumer_node.target == q_op - input_node_condition = False - if len(node.all_input_nodes) > 0: - input = node.all_input_nodes[0] - input_node_condition = input.target in dq_q_ops +def get_neighbour_quant_args( + node: torch.fx.Node, +) -> tuple[list[QuantArgs], list[QuantArgs]]: + user_q_args = [] - return node.target in dq_q_ops or consumer_node_condition or input_node_condition + for user in node.users: + q_args = search_quant_arg_downstream(user) + if q_args: + user_q_args.append(q_args) + input_q_nodes = [] + for input_node in node.all_input_nodes: + q_args = search_quant_arg_upstream(input_node) + if q_args: + input_q_nodes.append(q_args) + return user_q_args, input_q_nodes -def get_quant_node_dtype(node: torch.fx.Node): - # pyre-ignore[16]: Undefined attribute. - if "tosa" in node.target.__name__: - return node.meta["val"].dtype - if node.target in dq_q_ops: - return node.args[5] +def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: + first_q_arg = q_arg_list[0] + for q_arg in q_arg_list: + if q_arg != first_q_arg: + return False + return True - # if not a tosa node, nor a q/dq op, walk the graph until we find a q op - consumer_node = list(node.users)[0] - while True: - if consumer_node.target in dq_q_ops: - return consumer_node.args[5] - # Try to move on to the next node - if len(consumer_node.users) == 0: - raise RuntimeError(f"No quantized node found in graph for node {node}") - consumer_node = list(consumer_node.users)[0] +def is_node_quantized(node: torch.fx.Node) -> bool: + if node.target in dq_q_ops: + return True + user_q_args, input_q_args = get_neighbour_quant_args(node) -def is_quant_arg(arg): - consumer_node = list(arg.users)[0] - return consumer_node.target == q_op + # If we did not find any neighbouring quant nodes, we are not quantized. + if len(input_q_args) == 0 and len(user_q_args) == 0: + return False + if node.target in passable_ops: + assert all_q_args_equal( + user_q_args + input_q_args + ), f"Node {node} needs same quantization parameters on all inputs and outputs." -def get_quant_arg_dtype(node: torch.fx.Node): - consumer_node = list(node.users)[0] + return True - # Get type of quant node, args differ from per_tensor and per_channel. - if consumer_node.target == q_op: - if is_quant_arg(node): - return map_dtype(consumer_node.args[5]) - else: - raise RuntimeError("Quantization argument not found") + +def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple consumers is encountered, + find QuantArgs for all consumers and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without consumers is encountered, return None. + """ + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + consumer_nodes = list(node.users) + if len(consumer_nodes) == 0: + return None + elif len(consumer_nodes) == 1: + return search_quant_arg_downstream(consumer_nodes[0]) + else: + consumer_qargs: list[QuantArgs] = [] + for input in consumer_nodes: + quant_args = search_quant_arg_downstream(input) + if quant_args: + consumer_qargs.append(quant_args) + if len(consumer_qargs) == 0: + return None + assert all_q_args_equal( + consumer_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." + return consumer_qargs[0] + + +def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_downstream and asserts that QuantArgs are found, + meaning return value can't be None. + """ + qargs = search_quant_arg_downstream(node) + assert qargs, f"Did not find QuantArgs downstream for node {node}" + return qargs -def get_quant_node_args(node: torch.fx.Node): +def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple inputs is encountered, + find QuantArgs for all inputs and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without inputs is encountered, return None. """ - Get the quantization parameters from a quant node. - Args: - node: The quant node. - Returns: - QuantArgs: scale, zp, qmin, qmax + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + input_nodes = list(node.all_input_nodes) + if len(input_nodes) == 0: + return None + elif len(input_nodes) == 1: + return search_quant_arg_upstream(input_nodes[0]) + else: + input_qargs: list[QuantArgs] = [] + for input in input_nodes: + quant_args = search_quant_arg_upstream(input) + if quant_args: + input_qargs.append(quant_args) + if len(input_qargs) == 0: + return None + assert all_q_args_equal( + input_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." + return input_qargs[0] + + +def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_upstream and asserts that QuantArgs are found, + meaning return value can't be None. """ - quant_args = [TosaArg(arg) for arg in node.args] - return QuantArgs( - quant_args[1].number, - quant_args[2].number, - quant_args[3].number, - quant_args[4].number, - ) + qargs = search_quant_arg_upstream(node) + assert qargs, f"Did not find QuantArgs upstream for node {node}" + return qargs + + +def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype: + if isinstance(node.target, Callable) and "tosa" in node.target.__name__: + return node.meta["val"].dtype + if node.target in dq_q_ops: + return cast(torch.dtype, node.args[5]) + + # if not a tosa node, nor a q/dq op, walk the graph until we find a q op + user_q_args, input_q_args = get_neighbour_quant_args(node) + if len(user_q_args) > 0: + return user_q_args[0].dtype + elif node.target in passable_ops and len(input_q_args) > 0: + return input_q_args[0].dtype + else: + raise RuntimeError("No quantized node found in graph") # Check if scale32 mode is used for given output element type @@ -267,14 +375,14 @@ def rescale_nodes_to_int32( needed by rescale_node_back_to_int8. """ - tensors = [TosaArg(node.args[0]) for node in nodes] + tensors = [TosaArg(node) for node in nodes] # Reshape tensor according to tosa dim order for tensor in tensors: dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - qargs = [get_quant_node_args(node) for node in nodes] + qargs = [get_quant_arg_upstream(node) for node in nodes] # Scale the int8 quantized input to a common scale in the integer # domain @@ -307,7 +415,7 @@ def rescale_node_back_to_int8( scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' tosa_graph: the tosa_graph to manipulate. """ - qargs_out = get_quant_node_args(list(node.users)[0]) + qargs_out = get_quant_arg_downstream(list(node.users)[0]) output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -334,7 +442,7 @@ def build_rescale_conv_output( output_zp, ): # TODO add check to verify if this is a Per-channel quantization. - post_conv2d_scale = (input_scale.number * weight_scale.number) / output_scale.number + post_conv2d_scale = (input_scale * weight_scale) / output_scale # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. build_rescale( @@ -345,6 +453,6 @@ def build_rescale_conv_output( output_type, op.shape, 0, - output_zp.number, + output_zp, ) return diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index c91d89b1b9..b61b27853a 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -16,9 +16,10 @@ from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_node_args, - get_quant_node_dtype, - is_quant_node, + get_quant_arg_downstream, + get_quant_arg_upstream, + get_quantized_node_output_dtype, + is_node_quantized, q_op, ) from executorch.backends.arm.tosa_specification import TosaSpecification @@ -183,8 +184,8 @@ def build_avg_pool_2d_common( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args(cast(torch.fx.Node, node.args[0])).zp - output_zp = get_quant_node_args(list(node.users)[0]).zp + input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp + output_zp = get_quant_arg_downstream(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( @@ -244,10 +245,15 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) + is_quant_node = is_node_quantized(node) + if is_quant_node: + output_dtype = map_dtype(get_quantized_node_output_dtype(node)) + else: + output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( output.name, (tosa_shape(output.shape, output.dim_order)), - map_dtype(get_quant_node_dtype(node)) if is_quant_node(node) else output.dtype, + output_dtype, ) # Visiting each Node @@ -259,7 +265,7 @@ def process_call_function( tosa_graph, inputs, output, - is_quant_node(node), + is_quant_node, ) else: raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py index 4ffb80c2f0..b348f10722 100644 --- a/backends/arm/util/arm_model_evaluator.py +++ b/backends/arm/util/arm_model_evaluator.py @@ -7,7 +7,7 @@ import os import tempfile import zipfile -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple import torch @@ -32,7 +32,7 @@ def __init__( else: self.tosa_output_path = None - def get_model_error(self) -> Union[float, float, float, float]: + def get_model_error(self) -> tuple[float, float, float, float]: """ Returns the following metrics between the outputs of the FP32 and INT8 model: - Maximum error @@ -51,7 +51,12 @@ def get_model_error(self) -> Union[float, float, float, float]: max_percentage_error = torch.max(percentage_error).item() mean_absolute_error = torch.mean(torch.abs(difference).float()).item() - return max_error, max_absolute_error, max_percentage_error, mean_absolute_error + return ( + float(max_error), + float(max_absolute_error), + float(max_percentage_error), + float(mean_absolute_error), + ) def get_compression_ratio(self) -> float: """Compute the compression ratio of the outputted TOSA flatbuffer.""" @@ -67,7 +72,7 @@ def get_compression_ratio(self) -> float: return compression_ratio - def evaluate(self) -> dict[any]: + def evaluate(self) -> dict[str, Any]: max_error, max_absolute_error, max_percent_error, mean_absolute_error = ( self.get_model_error() ) @@ -82,6 +87,8 @@ def evaluate(self) -> dict[any]: } if self.tosa_output_path: + # We know output_metrics["metrics"] is list since we just defined it, safe to ignore. + # pyre-ignore[16] output_metrics["metrics"][ "compression_ratio" ] = self.get_compression_ratio() From 146ca1ba547ad1128224d83eeebc22f915d11e8f Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Mon, 11 Nov 2024 13:48:52 -0500 Subject: [PATCH 44/59] Swap mha (#6719) * Swap mha Move to extension/llm/modules Lint Add tests * Fix tests * Delete old file --- extension/llm/modules/README.md | 23 +- extension/llm/modules/mha.py | 404 +++++++++++++++++++++++++ extension/llm/modules/test/test_mha.py | 144 +++++++++ 3 files changed, 561 insertions(+), 10 deletions(-) create mode 100644 extension/llm/modules/mha.py create mode 100644 extension/llm/modules/test/test_mha.py diff --git a/extension/llm/modules/README.md b/extension/llm/modules/README.md index 3694f8b155..e6e1a20cec 100644 --- a/extension/llm/modules/README.md +++ b/extension/llm/modules/README.md @@ -1,14 +1,17 @@ -## Export Friendly Modules +## Export-friendly Modules -Modules in this directory are: -* Extending `torch.nn.Module`. -* Guranteed to work out of the box with `torch.export.export()` and `torch.aot_compile()`. -* Guranteed to be able to work with ExecuTorch. +Modules in this directory: +* Extend `torch.nn.Module`. +* Are guaranteed to work out of the box with `torch.export.export()`. +* Should work out of the box with `torch.aot_compile()`. +* Should be able to workt with ExecuTorch. All modules should be covered by unit tests to make sure they are: -1. giving the same output as the reference implementation in PyTorch or torchtune -2. export friendly -3. AOTI friendly -4. ExecuTorch friendly +1. Give the output as the reference eager model in PyTorch or TorrchTune +2. Export-friendly -Notice that these modules are subject to change (may upstream to torchtune) so proceed with caution. +Additionally, we aim to make these modules: +3. AOTI-friendly +4. ExecuTorch-friendly + +These modules are subject to change (may upstream to TorchTune) so proceed with caution. diff --git a/extension/llm/modules/mha.py b/extension/llm/modules/mha.py new file mode 100644 index 0000000000..0bfa4eb20c --- /dev/null +++ b/extension/llm/modules/mha.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +import torchtune.modules.attention as TorchTuneAttention +from torch import nn +from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention +from torchtune.modules.kv_cache import KVCache + +logger = logging.getLogger(__name__) + + +class MultiHeadAttention(nn.Module): + """ + NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except + that SDPA is factored out so that it can be swapped for more + efficient ExecuTorch-defined SDPA ops. + + Multi-headed attention layer with support for grouped query + attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1. + + GQA is a version of multiheaded attention (MHA) which uses fewer + key/value heads than query heads by grouping n query heads for each + key and value head. Multi-Query Attention is an extreme + version where we have a single key and value head shared by all + query heads. + + Following is an example of MHA, GQA and MQA with num_heads = 4 + + (credit for the documentation: + `litgpt.Config `_). + + + :: + + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ │ │ │ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + MHA GQA MQA + n_kv_heads =4 n_kv_heads=2 n_kv_heads=1 + + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function. + Default value is 0.0. + + Raises: + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # Use flex attention if supported and we are sample packing + self._attention_call = _sdpa_or_flex_attention() + self._sdpa = SDPA( + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + head_dim=self.head_dim, + q_per_kv=self.num_heads // self.num_kv_heads, + attn_dropout=self.attn_dropout if self.training else 0.0, + is_causal=self.is_causal, + attention_fn=self._attention_call, + kv_cache=self.kv_cache, + ) + + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self._sdpa.kv_cache = self.kv_cache + self.cache_enabled = True + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." + ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) + + # Apply positional embeddings + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) + + # Normalize q + if self.q_norm is not None: + q = self.q_norm(q) + + if y is None: + if self.kv_cache is None: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k has shape [b, s_y, num_kv_heads * head_dim] + # v has shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) + v = self.v_proj(y) + + # Apply positional embeddings + # k: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) + v = v.view(b, s_y, -1, self.head_dim) + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) + + # Normalize k + if self.k_norm is not None: + k = self.k_norm(k) + + # Update key-value cache + if self.kv_cache is not None and self.cache_enabled: + k, v = self.kv_cache.update(k, v) + + output = self._sdpa(q, k, v, b, s_x) + return self.output_proj(output) + + +class SDPA(nn.Module): + """ + TorchTune's SDPA which can be optimized and can be swapped + out for a more efficient implementations. + """ + + def __init__( + self, + num_kv_heads: int, + num_heads: int, + head_dim: int, + q_per_kv: int, + attn_dropout: float, + is_causal: bool, + attention_fn, + kv_cache, + ) -> None: + super().__init__() + self.num_kv_heads = num_kv_heads + self.num_heads = num_heads + self.head_dim = head_dim + self.q_per_kv = q_per_kv + self.attn_dropout = attn_dropout + self.is_causal = is_causal + self._attention_fn = attention_fn + self.kv_cache = kv_cache + + def forward( + self, + q: torch.Tensor, # [b, s, n_h, h_d] + k: torch.Tensor, # [b, s, n_kv, h_d] + v: torch.Tensor, # [b, s, n_kv, h_d] + bsz: int, + seq_len: int, + mask: torch.Tensor = None, + ) -> torch.Tensor: + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. + + # k: [bsz, seq_len, n_kv, 1, h_d] + # v: [bsz, seq_len, n_kv, 1, h_d] + k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + + # Expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + + # [bsz, s, n_h, h_d] + k = k.reshape(bsz, seq_len, -1, self.head_dim) + v = v.reshape(bsz, seq_len, -1, self.head_dim) + + # [bsz, n_h, s, h_d] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output = self._attention_fn( + q, + k, + v, + mask=mask, + dropout_p=self.attn_dropout, + is_causal=self.kv_cache is None and mask is None and self.is_causal, + ) + # Reshape the output to be the same shape as the input + return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + + +def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None: + for name, child in module.named_children(): + if isinstance(child, TorchTuneAttention.MultiHeadAttention): + setattr( + module, + name, + MultiHeadAttention( + embed_dim=child.embed_dim, + num_heads=child.num_heads, + num_kv_heads=child.num_kv_heads, + head_dim=child.head_dim, + q_proj=child.q_proj, + k_proj=child.k_proj, + v_proj=child.v_proj, + output_proj=child.output_proj, + pos_embeddings=child.pos_embeddings, + q_norm=child.q_norm, + k_norm=child.k_norm, + kv_cache=child.kv_cache, + max_seq_len=child.max_seq_len, + is_causal=child.is_causal, + attn_dropout=child.attn_dropout, + ), + ) + else: + replace_mha_with_inference_mha(child) + + +def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module: + """ + Replace TorchTune's MHA with an inference friendly version of MHA that + separates out the inference-related parts for further optimization. + """ + _replace_mha_with_inference_mha(module) + return module diff --git a/extension/llm/modules/test/test_mha.py b/extension/llm/modules/test/test_mha.py new file mode 100644 index 0000000000..0dc7cba685 --- /dev/null +++ b/extension/llm/modules/test/test_mha.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.exir import EdgeCompileConfig, to_edge + +from executorch.extension.llm.modules.mha import ( + MultiHeadAttention as ETMultiHeadAttention, +) +from executorch.runtime import Runtime +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention + + +torch.manual_seed(0) + + +class AttentionTest(unittest.TestCase): + def setUp(self): + super().setUp() + + # Constants + self.embed_dim = 2048 + self.num_heads = 32 + self.num_kv_heads = 8 + self.head_dim = 64 + self.max_seq_len = 128 + self.rope_base = 500_000 + self.scale_factor = 32 + + # Module dependency injections. + self.q_proj = torch.nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = torch.nn.Linear( + self.embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.v_proj = torch.nn.Linear( + self.embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.pos_embeddings = Llama3ScaledRoPE( + dim=self.head_dim, + max_seq_len=self.max_seq_len, + base=self.rope_base, + scale_factor=self.scale_factor, + ) + + # Original TorchTune reference module to test accuracy against. + self.tt_mha = TTMultiHeadAttention( + embed_dim=self.embed_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=self.pos_embeddings, + max_seq_len=self.max_seq_len, + ) + + # Source transformed module that we are testing. + self.et_mha = ETMultiHeadAttention( + embed_dim=self.embed_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=self.pos_embeddings, + max_seq_len=self.max_seq_len, + ) + + # Common inputs. + seq_len = 10 + self.x = torch.randn(1, seq_len, self.embed_dim) + seq_len_dim = torch.export.Dim("seq_len", min=1, max=100) + self.dynamic_shapes = ( + {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, + {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC}, + ) + + def test_attention_eager(self): + et_res = self.et_mha(self.x, self.x) # Self attention. + tt_res = self.tt_mha(self.x, self.x) # Self attention. + + self.assertTrue(torch.allclose(et_res, tt_res)) + + # TODO: KV cache. + # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20) + # self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20) + + # et_res = self.et_mha(self.x, self.x) # Self attention. + # tt_res = self.tt_mha(self.x, self.x) # Self attention. + + # self.assertTrue(torch.allclose(et_res, tt_res)) + + def test_attention_export(self): + # Self attention. + et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.x), + kwargs=None, + dynamic_shapes=self.dynamic_shapes, + ) + et_res = et_mha_ep.module()(self.x, self.x) + tt_res = self.tt_mha(self.x, self.x) + self.assertTrue(torch.allclose(et_res, tt_res)) + + # TODO: KV cache. + + def test_attention_aoti(self): + # TODO. + pass + + def test_attention_executorch(self): + # Self attention. + et_mha_ep = torch.export.export( + self.et_mha, + (self.x, self.x), + kwargs=None, + dynamic_shapes=self.dynamic_shapes, + ) + et_program = to_edge( + et_mha_ep, + compile_config=EdgeCompileConfig(), + ).to_executorch() + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + et_res = method.execute((self.x, self.x)) + tt_res = self.tt_mha(self.x, self.x) + + self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06)) + + # TODO: KV cache. From a809953b3aa6d838d74b27bf2f0514f67172c51d Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:03:24 -0800 Subject: [PATCH 45/59] Add torchao kernels to llama runner Differential Revision: D64942925 Pull Request resolved: https://github.com/pytorch/executorch/pull/6195 --- .gitmodules | 3 + examples/models/llama/CMakeLists.txt | 9 +++ examples/models/llama/export_llama_lib.py | 36 ++++++++++- examples/models/llama/install_requirements.sh | 3 +- .../llama/source_transformation/quantize.py | 61 +++++++++++++++++++ .../llama3_2_vision/install_requirements.sh | 3 +- .../phi-3-mini-lora/install_requirements.sh | 3 +- third-party/ao | 1 + 8 files changed, 110 insertions(+), 9 deletions(-) create mode 160000 third-party/ao diff --git a/.gitmodules b/.gitmodules index 844cd91789..6844743d73 100644 --- a/.gitmodules +++ b/.gitmodules @@ -64,3 +64,6 @@ [submodule "third-party/pybind11"] path = third-party/pybind11 url = https://github.com/pybind/pybind11.git +[submodule "third-party/ao"] + path = third-party/ao + url = https://github.com/pytorch/ao.git diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index b1401a0bca..6a4aee11d2 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -37,6 +37,8 @@ cmake_dependent_option( "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF ) +option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF) + if(NOT PYTHON_EXECUTABLE) set(PYTHON_EXECUTABLE python3) endif() @@ -121,6 +123,13 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM) list(APPEND link_libraries custom_ops) endif() +if(EXECUTORCH_BUILD_TORCHAO) + set(TORCHAO_BUILD_EXECUTORCH_OPS ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental) + target_link_options_shared_lib(torchao_ops_executorch) + list(APPEND link_libraries torchao_ops_executorch) +endif() + set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack) # Extra compile option and include dir for pthreadpool if(EXECUTORCH_BUILD_PTHREADPOOL) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 23b3589c2a..e1a8d1d06b 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -12,6 +12,7 @@ import copy import json import logging +import re import shlex from enum import Enum from json import JSONDecodeError @@ -19,7 +20,6 @@ from typing import Callable, List, Optional, Union import pkg_resources - import torch from executorch.devtools.etrecord import generate_etrecord @@ -153,12 +153,12 @@ def build_args_parser() -> argparse.ArgumentParser: ], help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", ) + parser.add_argument( "-qmode", "--quantization_mode", - type=str, + type=_qmode_type, default=None, - choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"], help="type of quantization", ) @@ -568,6 +568,23 @@ def get_quantizer_and_quant_params(args): return pt2e_quant_params, quantizers, quant_dtype +def _qmode_type(value): + choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"] + patterns = [r"torchao:8da(\d+)w"] + + if value in choices: + return value + + for pattern in patterns: + matches = re.findall(pattern, value) + if len(matches) == 1: + return value + + raise argparse.ArgumentTypeError( + f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}." + ) + + def _validate_args(args): """ TODO: Combine all the backends under --backend args @@ -581,6 +598,19 @@ def _validate_args(args): if args.num_sharding > 0 and not args.qnn: raise ValueError("Model shard is only supported with qnn backend now.") + if ( + args.quantization_mode is not None + and args.quantization_mode.startswith("torchao:") + ) or ( + args.embedding_quantize is not None + and args.embedding_quantize.startswith("torchao:") + ): + if args.enable_dynamic_shape: + raise ValueError( + "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape." + "If you need this feature, please file an issue." + ) + def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 _validate_args(args) diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh index 3103daeb7d..f794b660bd 100755 --- a/examples/models/llama/install_requirements.sh +++ b/examples/models/llama/install_requirements.sh @@ -10,8 +10,7 @@ pip install snakeviz sentencepiece # Install torchao. -TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt) -pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}" +pip install "$(dirname "$0")/../../../third-party/ao" # Install lm-eval for Model Evaluation with lm-evalution-harness # Install tiktoken for tokenizer diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index 162d41d659..d168b7efcd 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +import re from functools import partial from pathlib import Path from typing import Any, Dict, Optional @@ -70,6 +72,26 @@ def quantize( # noqa C901 if qmode == "int8": # Add quantization mode options here: group size, bit width, etc. return WeightOnlyInt8QuantHandler(model).quantized_model() + elif qmode.startswith("torchao:"): + pattern = r"torchao:8da(\d+)w" + matches = re.findall(pattern, qmode) + assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}" + bitwidth = int(matches[0][0]) + _load_torchao_ops_aten() + from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer + + with torch.no_grad(): + model = Int8DynActIntxWeightLinearQuantizer( + device="cpu", + precision=torch.float32, + groupsize=group_size, + bitwidth=bitwidth, + has_weight_zeros=False, + ).quantize(model) + + if verbose: + print("quantized model:", model) + return model elif qmode == "8da4w": # Check for required args if group_size is None: @@ -79,6 +101,7 @@ def quantize( # noqa C901 model = Int8DynActInt4WeightQuantizer( precision=torch_dtype, groupsize=group_size ).quantize(model) + if verbose: print("quantized model:", model) return model @@ -692,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: def get_quant_embedding_transform(args): + if args.embedding_quantize.startswith("torchao:"): + bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",") + group_size = int(group_size) + bitwidth = int(bitwidth) + _load_torchao_ops_aten() + from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer + + def _torchao_embedding_quantizer(model): + with torch.no_grad(): + model = IntxWeightEmbeddingQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=bitwidth, + groupsize=group_size, + ).quantize(model) + return model + + return _torchao_embedding_quantizer + bitwidth, group_size = args.embedding_quantize.split(",") if group_size == "none" or group_size == "None" or group_size == "0": group_size = None @@ -733,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose): ) +def _load_torchao_ops_aten(): + import glob + import os + + libs = glob.glob( + os.path.abspath( + os.path.join( + os.environ.get("CMAKE_INSTALL_PREFIX", ""), + "lib/libtorchao_ops_aten.*", + ) + ) + ) + assert ( + len(libs) == 1 + ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly." + logging.info(f"Loading custom ops library: {libs[0]}") + torch.ops.load_library(libs[0]) + + ############################ Source Transform End ####################### diff --git a/examples/models/llama3_2_vision/install_requirements.sh b/examples/models/llama3_2_vision/install_requirements.sh index 44cc399acb..49558952d8 100755 --- a/examples/models/llama3_2_vision/install_requirements.sh +++ b/examples/models/llama3_2_vision/install_requirements.sh @@ -9,5 +9,4 @@ pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir # Install torchao. -TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt) -pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}" +pip install "$(dirname "$0")/../../../third-party/ao" diff --git a/examples/models/phi-3-mini-lora/install_requirements.sh b/examples/models/phi-3-mini-lora/install_requirements.sh index ec6289a126..2cd74d0cd4 100755 --- a/examples/models/phi-3-mini-lora/install_requirements.sh +++ b/examples/models/phi-3-mini-lora/install_requirements.sh @@ -10,5 +10,4 @@ pip install torchtune pip install tiktoken # Install torchao. -TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt) -pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}" +pip install "$(dirname "$0")/../../../third-party/ao" diff --git a/third-party/ao b/third-party/ao new file mode 160000 index 0000000000..75d06933aa --- /dev/null +++ b/third-party/ao @@ -0,0 +1 @@ +Subproject commit 75d06933aace9d1ce803158e52910e4c9fc60981 From bec0625dbb43703719c73f4093dd2e5c7f9264b0 Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:11:23 -0800 Subject: [PATCH 46/59] Fix arm related internal build Differential Revision: D65694379 Pull Request resolved: https://github.com/pytorch/executorch/pull/6743 --- backends/arm/TARGETS | 12 ++++++++++++ backends/arm/operators/TARGETS | 1 + 2 files changed, 13 insertions(+) diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index 0dc8797be5..a73973ad04 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -70,6 +70,18 @@ python_library( ], ) +python_library( + name = "tosa_specification", + srcs = [ + "tosa_specification.py", + ], + typing = True, + deps = [ + "fbsource//third-party/pypi/packaging:packaging", + "//executorch/exir/backend:compile_spec_schema", + ], +) + python_library( name = "tosa_utils", srcs = [ diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS index c2aa8d2dfb..d12cc7e4df 100644 --- a/backends/arm/operators/TARGETS +++ b/backends/arm/operators/TARGETS @@ -7,6 +7,7 @@ python_library( typing = True, deps = [ "//executorch/backends/arm:tosa_mapping", + "//executorch/backends/arm:tosa_specification", ], ) From 789598290cb79a43cedda5351c4cf6f9826935cc Mon Sep 17 00:00:00 2001 From: haowhsu-quic <111341466+haowhsu-quic@users.noreply.github.com> Date: Tue, 12 Nov 2024 03:16:48 +0800 Subject: [PATCH 47/59] Qualcomm AI Engine Direct - wav2letter e2e example Differential Revision: D65734745 Pull Request resolved: https://github.com/pytorch/executorch/pull/5924 --- backends/qualcomm/tests/test_qnn_delegate.py | 38 +++ .../qualcomm/scripts/install_requirement.sh | 2 + examples/qualcomm/scripts/wav2letter.py | 226 ++++++++++++++++++ 3 files changed, 266 insertions(+) create mode 100644 examples/qualcomm/scripts/install_requirement.sh create mode 100644 examples/qualcomm/scripts/wav2letter.py diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4bfdedcd4b..875cfbf956 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2918,6 +2918,44 @@ def test_ptq_mobilebert(self): for k, v in cpu.items(): self.assertLessEqual(abs(v[0] - htp[k][0]), 5) + def test_wav2letter(self): + if not self.required_envs([self.pretrained_weight]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/scripts/wav2letter.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--pretrained_weight", + self.pretrained_weight, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + if self.shared_buffer: + cmds.extend(["--shared_buffer"]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertLessEqual(msg["wer"], 0.5) + self.assertLessEqual(msg["cer"], 0.25) + def test_export_example(self): if not self.required_envs([self.model_name]): self.skipTest("missing required envs") diff --git a/examples/qualcomm/scripts/install_requirement.sh b/examples/qualcomm/scripts/install_requirement.sh new file mode 100644 index 0000000000..c961467a8a --- /dev/null +++ b/examples/qualcomm/scripts/install_requirement.sh @@ -0,0 +1,2 @@ +pip install soundfile +pip install torchmetrics diff --git a/examples/qualcomm/scripts/wav2letter.py b/examples/qualcomm/scripts/wav2letter.py new file mode 100644 index 0000000000..e377c6d7e9 --- /dev/null +++ b/examples/qualcomm/scripts/wav2letter.py @@ -0,0 +1,226 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import sys +from multiprocessing.connection import Client + +import numpy as np + +import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.examples.models.wav2letter import Wav2LetterModel +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) + + +class Conv2D(torch.nn.Module): + def __init__(self, stride, padding, weight, bias=None): + super().__init__() + use_bias = bias is not None + self.conv = torch.nn.Conv2d( + in_channels=weight.shape[1], + out_channels=weight.shape[0], + kernel_size=[weight.shape[2], 1], + stride=[*stride, 1], + padding=[*padding, 0], + bias=use_bias, + ) + self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1)) + if use_bias: + self.conv.bias = torch.nn.Parameter(bias) + + def forward(self, x): + return self.conv(x) + + +def get_dataset(data_size, artifact_dir): + from torch.utils.data import DataLoader + from torchaudio.datasets import LIBRISPEECH + + def collate_fun(batch): + waves, labels = [], [] + + for wave, _, text, *_ in batch: + waves.append(wave.squeeze(0)) + labels.append(text) + # need padding here for static ouput shape + waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True) + return waves, labels + + dataset = LIBRISPEECH(artifact_dir, url="test-clean", download=True) + data_loader = DataLoader( + dataset=dataset, + batch_size=data_size, + shuffle=True, + collate_fn=lambda x: collate_fun(x), + ) + # prepare input data + inputs, targets, input_list = [], [], "" + for wave, label in data_loader: + for index in range(data_size): + # reshape input tensor to NCHW + inputs.append((wave[index].reshape(1, 1, -1, 1),)) + targets.append(label[index]) + input_list += f"input_{index}_0.raw\n" + # here we only take first batch, i.e. 'data_size' tensors + break + + return inputs, targets, input_list + + +def eval_metric(pred, target_str): + from torchmetrics.text import CharErrorRate, WordErrorRate + + def parse(ids): + vocab = " abcdefghijklmnopqrstuvwxyz'*" + return ["".join([vocab[c] for c in id]).replace("*", "").upper() for id in ids] + + pred_str = parse( + [ + torch.unique_consecutive(pred[i, :, :].argmax(0)) + for i in range(pred.shape[0]) + ] + ) + wer, cer = WordErrorRate(), CharErrorRate() + return wer(pred_str, target_str), cer(pred_str, target_str) + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + + instance = Wav2LetterModel() + # target labels " abcdefghijklmnopqrstuvwxyz'*" + instance.vocab_size = 29 + model = instance.get_eager_model().eval() + model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True)) + + # convert conv1d to conv2d in nn.Module level will only introduce 2 permute + # nodes around input & output, which is more quantization friendly. + for i in range(len(model.acoustic_model)): + for j in range(len(model.acoustic_model[i])): + module = model.acoustic_model[i][j] + if isinstance(module, torch.nn.Conv1d): + model.acoustic_model[i][j] = Conv2D( + stride=module.stride, + padding=module.padding, + weight=module.weight, + bias=module.bias, + ) + + # retrieve dataset, will take some time to download + data_num = 100 + inputs, targets, input_list = get_dataset( + data_size=data_num, artifact_dir=args.artifact + ) + pte_filename = "w2l_qnn" + build_executorch_binary( + model, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + sys.exit(0) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + adb.pull(output_path=args.artifact) + + predictions = [] + for i in range(data_num): + predictions.append( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + ) + + # evaluate metrics + wer, cer = 0, 0 + for i, pred in enumerate(predictions): + pred = torch.from_numpy(pred).reshape(1, instance.vocab_size, -1) + wer_eval, cer_eval = eval_metric(pred, targets[i]) + wer += wer_eval + cer += cer_eval + + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps({"wer": wer.item() / data_num, "cer": cer.item() / data_num}) + ) + else: + print(f"wer: {wer / data_num}\ncer: {cer / data_num}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./wav2letter", + default="./wav2letter", + type=str, + ) + + parser.add_argument( + "-p", + "--pretrained_weight", + help=( + "Location of pretrained weight, please download via " + "https://github.com/nipponjo/wav2letter-ctc-pytorch/tree/main?tab=readme-ov-file#wav2letter-ctc-pytorch" + " for torchaudio.models.Wav2Letter version" + ), + default=None, + type=str, + required=True, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) From b23c9e62a9df12c811a90f9916d366af39ce3852 Mon Sep 17 00:00:00 2001 From: David Lin Date: Mon, 11 Nov 2024 12:28:19 -0800 Subject: [PATCH 48/59] [Android] Added instrumentation test for Module (#6751) added instrumentation test --- extension/android_test/.gitignore | 6 + extension/android_test/TARGETS | 1 + extension/android_test/build.gradle | 65 +++++ extension/android_test/gradle.properties | 23 ++ .../android_test/gradle/libs.versions.toml | 12 + .../gradle/wrapper/gradle-wrapper.jar | Bin 0 -> 43462 bytes .../gradle/wrapper/gradle-wrapper.properties | 7 + extension/android_test/gradlew | 249 ++++++++++++++++++ extension/android_test/gradlew.bat | 92 +++++++ extension/android_test/settings.gradle | 24 ++ extension/android_test/setup.sh | 53 ++++ .../executorch/ModuleInstrumentationTest.java | 130 +++++++++ .../src/androidTest/resources/test.txt | 1 + .../android_test/src/main/AndroidManifest.xml | 12 + 14 files changed, 675 insertions(+) create mode 100644 extension/android_test/.gitignore create mode 100644 extension/android_test/TARGETS create mode 100644 extension/android_test/build.gradle create mode 100644 extension/android_test/gradle.properties create mode 100644 extension/android_test/gradle/libs.versions.toml create mode 100644 extension/android_test/gradle/wrapper/gradle-wrapper.jar create mode 100644 extension/android_test/gradle/wrapper/gradle-wrapper.properties create mode 100755 extension/android_test/gradlew create mode 100644 extension/android_test/gradlew.bat create mode 100644 extension/android_test/settings.gradle create mode 100755 extension/android_test/setup.sh create mode 100644 extension/android_test/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java create mode 100644 extension/android_test/src/androidTest/resources/test.txt create mode 100644 extension/android_test/src/main/AndroidManifest.xml diff --git a/extension/android_test/.gitignore b/extension/android_test/.gitignore new file mode 100644 index 0000000000..a43b7e827a --- /dev/null +++ b/extension/android_test/.gitignore @@ -0,0 +1,6 @@ +local.properties +.gradle +.idea/* +.externalNativeBuild +src/libs/* +build diff --git a/extension/android_test/TARGETS b/extension/android_test/TARGETS new file mode 100644 index 0000000000..5c4f482b5e --- /dev/null +++ b/extension/android_test/TARGETS @@ -0,0 +1 @@ +# This file needs to exist to avoid build system breakage, see https://fburl.com/workplace/jtdlgdmd diff --git a/extension/android_test/build.gradle b/extension/android_test/build.gradle new file mode 100644 index 0000000000..5beb5455cb --- /dev/null +++ b/extension/android_test/build.gradle @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + buildscript { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } + dependencies { + classpath 'com.android.tools.build:gradle:7.3.0' + } +} + + +apply plugin: 'com.android.library' + +group 'org.pytorch.executorch' + + +android { + namespace 'org.pytorch.executorch' + compileSdkVersion 31 + buildToolsVersion "29.0.0" + defaultConfig { + minSdkVersion 28 + targetSdkVersion 31 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + sourceSets { + androidTest { + resources.srcDirs += [ 'src/androidTest/resources' ] + } + } +} + +dependencies { + implementation 'com.facebook.soloader:nativeloader:0.10.5' + implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation(files("src/libs/executorch.aar")) + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.1.5' + androidTestImplementation 'androidx.test:rules:1.2.0' + androidTestImplementation 'commons-io:commons-io:2.4' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' + androidTestImplementation 'com.google.gms:google-services:4.3.3' +} + +task('setupNativeLibs', type: Exec){ + commandLine("sh", "setup.sh") +} + +gradle.projectsEvaluated { + preBuild.dependsOn setupNativeLibs +} diff --git a/extension/android_test/gradle.properties b/extension/android_test/gradle.properties new file mode 100644 index 0000000000..2cbd6d19d3 --- /dev/null +++ b/extension/android_test/gradle.properties @@ -0,0 +1,23 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app's APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official +# Enables namespacing of each library's R class so that its R class includes only the +# resources declared in the library itself and none from the library's dependencies, +# thereby reducing the size of the R class for that library +android.nonTransitiveRClass=true diff --git a/extension/android_test/gradle/libs.versions.toml b/extension/android_test/gradle/libs.versions.toml new file mode 100644 index 0000000000..561988cb1f --- /dev/null +++ b/extension/android_test/gradle/libs.versions.toml @@ -0,0 +1,12 @@ +# This file was generated by the Gradle 'init' task. +# https://docs.gradle.org/current/userguide/platforms.html#sub::toml-dependencies-format + +[versions] +commons-math3 = "3.6.1" +guava = "32.1.3-jre" +junit = "4.13.2" + +[libraries] +commons-math3 = { module = "org.apache.commons:commons-math3", version.ref = "commons-math3" } +guava = { module = "com.google.guava:guava", version.ref = "guava" } +junit = { module = "junit:junit", version.ref = "junit" } diff --git a/extension/android_test/gradle/wrapper/gradle-wrapper.jar b/extension/android_test/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..d64cd4917707c1f8861d8cb53dd15194d4248596 GIT binary patch literal 43462 zcma&NWl&^owk(X(xVyW%ySuwf;qI=D6|RlDJ2cR^yEKh!@I- zp9QeisK*rlxC>+~7Dk4IxIRsKBHqdR9b3+fyL=ynHmIDe&|>O*VlvO+%z5;9Z$|DJ zb4dO}-R=MKr^6EKJiOrJdLnCJn>np?~vU-1sSFgPu;pthGwf}bG z(1db%xwr#x)r+`4AGu$j7~u2MpVs3VpLp|mx&;>`0p0vH6kF+D2CY0fVdQOZ@h;A` z{infNyvmFUiu*XG}RNMNwXrbec_*a3N=2zJ|Wh5z* z5rAX$JJR{#zP>KY**>xHTuw?|-Rg|o24V)74HcfVT;WtQHXlE+_4iPE8QE#DUm%x0 zEKr75ur~W%w#-My3Tj`hH6EuEW+8K-^5P62$7Sc5OK+22qj&Pd1;)1#4tKihi=~8C zHiQSst0cpri6%OeaR`PY>HH_;CPaRNty%WTm4{wDK8V6gCZlG@U3$~JQZ;HPvDJcT1V{ z?>H@13MJcCNe#5z+MecYNi@VT5|&UiN1D4ATT+%M+h4c$t;C#UAs3O_q=GxK0}8%8 z8J(_M9bayxN}69ex4dzM_P3oh@ZGREjVvn%%r7=xjkqxJP4kj}5tlf;QosR=%4L5y zWhgejO=vao5oX%mOHbhJ8V+SG&K5dABn6!WiKl{|oPkq(9z8l&Mm%(=qGcFzI=eLu zWc_oCLyf;hVlB@dnwY98?75B20=n$>u3b|NB28H0u-6Rpl((%KWEBOfElVWJx+5yg z#SGqwza7f}$z;n~g%4HDU{;V{gXIhft*q2=4zSezGK~nBgu9-Q*rZ#2f=Q}i2|qOp z!!y4p)4o=LVUNhlkp#JL{tfkhXNbB=Ox>M=n6soptJw-IDI|_$is2w}(XY>a=H52d z3zE$tjPUhWWS+5h=KVH&uqQS=$v3nRs&p$%11b%5qtF}S2#Pc`IiyBIF4%A!;AVoI zXU8-Rpv!DQNcF~(qQnyyMy=-AN~U>#&X1j5BLDP{?K!%h!;hfJI>$mdLSvktEr*89 zdJHvby^$xEX0^l9g$xW-d?J;L0#(`UT~zpL&*cEh$L|HPAu=P8`OQZV!-}l`noSp_ zQ-1$q$R-gDL)?6YaM!=8H=QGW$NT2SeZlb8PKJdc=F-cT@j7Xags+Pr*jPtlHFnf- zh?q<6;)27IdPc^Wdy-mX%2s84C1xZq9Xms+==F4);O`VUASmu3(RlgE#0+#giLh-& zcxm3_e}n4{%|X zJp{G_j+%`j_q5}k{eW&TlP}J2wtZ2^<^E(O)4OQX8FDp6RJq!F{(6eHWSD3=f~(h} zJXCf7=r<16X{pHkm%yzYI_=VDP&9bmI1*)YXZeB}F? z(%QsB5fo*FUZxK$oX~X^69;x~j7ms8xlzpt-T15e9}$4T-pC z6PFg@;B-j|Ywajpe4~bk#S6(fO^|mm1hKOPfA%8-_iGCfICE|=P_~e;Wz6my&)h_~ zkv&_xSAw7AZ%ThYF(4jADW4vg=oEdJGVOs>FqamoL3Np8>?!W#!R-0%2Bg4h?kz5I zKV-rKN2n(vUL%D<4oj@|`eJ>0i#TmYBtYmfla;c!ATW%;xGQ0*TW@PTlGG><@dxUI zg>+3SiGdZ%?5N=8uoLA|$4isK$aJ%i{hECP$bK{J#0W2gQ3YEa zZQ50Stn6hqdfxJ*9#NuSLwKFCUGk@c=(igyVL;;2^wi4o30YXSIb2g_ud$ zgpCr@H0qWtk2hK8Q|&wx)}4+hTYlf;$a4#oUM=V@Cw#!$(nOFFpZ;0lc!qd=c$S}Z zGGI-0jg~S~cgVT=4Vo)b)|4phjStD49*EqC)IPwyeKBLcN;Wu@Aeph;emROAwJ-0< z_#>wVm$)ygH|qyxZaet&(Vf%pVdnvKWJn9`%DAxj3ot;v>S$I}jJ$FLBF*~iZ!ZXE zkvui&p}fI0Y=IDX)mm0@tAd|fEHl~J&K}ZX(Mm3cm1UAuwJ42+AO5@HwYfDH7ipIc zmI;1J;J@+aCNG1M`Btf>YT>~c&3j~Qi@Py5JT6;zjx$cvOQW@3oQ>|}GH?TW-E z1R;q^QFjm5W~7f}c3Ww|awg1BAJ^slEV~Pk`Kd`PS$7;SqJZNj->it4DW2l15}xP6 zoCl$kyEF%yJni0(L!Z&14m!1urXh6Btj_5JYt1{#+H8w?5QI%% zo-$KYWNMJVH?Hh@1n7OSu~QhSswL8x0=$<8QG_zepi_`y_79=nK=_ZP_`Em2UI*tyQoB+r{1QYZCpb?2OrgUw#oRH$?^Tj!Req>XiE#~B|~ z+%HB;=ic+R@px4Ld8mwpY;W^A%8%l8$@B@1m5n`TlKI6bz2mp*^^^1mK$COW$HOfp zUGTz-cN9?BGEp}5A!mDFjaiWa2_J2Iq8qj0mXzk; z66JBKRP{p%wN7XobR0YjhAuW9T1Gw3FDvR5dWJ8ElNYF94eF3ebu+QwKjtvVu4L zI9ip#mQ@4uqVdkl-TUQMb^XBJVLW(-$s;Nq;@5gr4`UfLgF$adIhd?rHOa%D);whv z=;krPp~@I+-Z|r#s3yCH+c1US?dnm+C*)r{m+86sTJusLdNu^sqLrfWed^ndHXH`m zd3#cOe3>w-ga(Dus_^ppG9AC>Iq{y%%CK+Cro_sqLCs{VLuK=dev>OL1dis4(PQ5R zcz)>DjEkfV+MO;~>VUlYF00SgfUo~@(&9$Iy2|G0T9BSP?&T22>K46D zL*~j#yJ?)^*%J3!16f)@Y2Z^kS*BzwfAQ7K96rFRIh>#$*$_Io;z>ux@}G98!fWR@ zGTFxv4r~v)Gsd|pF91*-eaZ3Qw1MH$K^7JhWIdX%o$2kCbvGDXy)a?@8T&1dY4`;L z4Kn+f%SSFWE_rpEpL9bnlmYq`D!6F%di<&Hh=+!VI~j)2mfil03T#jJ_s?}VV0_hp z7T9bWxc>Jm2Z0WMU?`Z$xE74Gu~%s{mW!d4uvKCx@WD+gPUQ zV0vQS(Ig++z=EHN)BR44*EDSWIyT~R4$FcF*VEY*8@l=218Q05D2$|fXKFhRgBIEE zdDFB}1dKkoO^7}{5crKX!p?dZWNz$m>1icsXG2N+((x0OIST9Zo^DW_tytvlwXGpn zs8?pJXjEG;T@qrZi%#h93?FP$!&P4JA(&H61tqQi=opRzNpm zkrG}$^t9&XduK*Qa1?355wd8G2CI6QEh@Ua>AsD;7oRUNLPb76m4HG3K?)wF~IyS3`fXuNM>${?wmB zpVz;?6_(Fiadfd{vUCBM*_kt$+F3J+IojI;9L(gc9n3{sEZyzR9o!_mOwFC#tQ{Q~ zP3-`#uK#tP3Q7~Q;4H|wjZHO8h7e4IuBxl&vz2w~D8)w=Wtg31zpZhz%+kzSzL*dV zwp@{WU4i;hJ7c2f1O;7Mz6qRKeASoIv0_bV=i@NMG*l<#+;INk-^`5w@}Dj~;k=|}qM1vq_P z|GpBGe_IKq|LNy9SJhKOQ$c=5L{Dv|Q_lZl=-ky*BFBJLW9&y_C|!vyM~rQx=!vun z?rZJQB5t}Dctmui5i31C_;_}CEn}_W%>oSXtt>@kE1=JW*4*v4tPp;O6 zmAk{)m!)}34pTWg8{i>($%NQ(Tl;QC@J@FfBoc%Gr&m560^kgSfodAFrIjF}aIw)X zoXZ`@IsMkc8_=w%-7`D6Y4e*CG8k%Ud=GXhsTR50jUnm+R*0A(O3UKFg0`K;qp1bl z7``HN=?39ic_kR|^R^~w-*pa?Vj#7|e9F1iRx{GN2?wK!xR1GW!qa=~pjJb-#u1K8 zeR?Y2i-pt}yJq;SCiVHODIvQJX|ZJaT8nO+(?HXbLefulKKgM^B(UIO1r+S=7;kLJ zcH}1J=Px2jsh3Tec&v8Jcbng8;V-`#*UHt?hB(pmOipKwf3Lz8rG$heEB30Sg*2rx zV<|KN86$soN(I!BwO`1n^^uF2*x&vJ$2d$>+`(romzHP|)K_KkO6Hc>_dwMW-M(#S zK(~SiXT1@fvc#U+?|?PniDRm01)f^#55;nhM|wi?oG>yBsa?~?^xTU|fX-R(sTA+5 zaq}-8Tx7zrOy#3*JLIIVsBmHYLdD}!0NP!+ITW+Thn0)8SS!$@)HXwB3tY!fMxc#1 zMp3H?q3eD?u&Njx4;KQ5G>32+GRp1Ee5qMO0lZjaRRu&{W<&~DoJNGkcYF<5(Ab+J zgO>VhBl{okDPn78<%&e2mR{jwVCz5Og;*Z;;3%VvoGo_;HaGLWYF7q#jDX=Z#Ml`H z858YVV$%J|e<1n`%6Vsvq7GmnAV0wW4$5qQ3uR@1i>tW{xrl|ExywIc?fNgYlA?C5 zh$ezAFb5{rQu6i7BSS5*J-|9DQ{6^BVQ{b*lq`xS@RyrsJN?-t=MTMPY;WYeKBCNg z^2|pN!Q^WPJuuO4!|P@jzt&tY1Y8d%FNK5xK(!@`jO2aEA*4 zkO6b|UVBipci?){-Ke=+1;mGlND8)6+P;8sq}UXw2hn;fc7nM>g}GSMWu&v&fqh

iViYT=fZ(|3Ox^$aWPp4a8h24tD<|8-!aK0lHgL$N7Efw}J zVIB!7=T$U`ao1?upi5V4Et*-lTG0XvExbf!ya{cua==$WJyVG(CmA6Of*8E@DSE%L z`V^$qz&RU$7G5mg;8;=#`@rRG`-uS18$0WPN@!v2d{H2sOqP|!(cQ@ zUHo!d>>yFArLPf1q`uBvY32miqShLT1B@gDL4XoVTK&@owOoD)OIHXrYK-a1d$B{v zF^}8D3Y^g%^cnvScOSJR5QNH+BI%d|;J;wWM3~l>${fb8DNPg)wrf|GBP8p%LNGN# z3EaIiItgwtGgT&iYCFy9-LG}bMI|4LdmmJt@V@% zb6B)1kc=T)(|L@0;wr<>=?r04N;E&ef+7C^`wPWtyQe(*pD1pI_&XHy|0gIGHMekd zF_*M4yi6J&Z4LQj65)S zXwdM{SwUo%3SbPwFsHgqF@V|6afT|R6?&S;lw=8% z3}@9B=#JI3@B*#4s!O))~z zc>2_4Q_#&+5V`GFd?88^;c1i7;Vv_I*qt!_Yx*n=;rj!82rrR2rQ8u5(Ejlo{15P% zs~!{%XJ>FmJ})H^I9bn^Re&38H{xA!0l3^89k(oU;bZWXM@kn$#aoS&Y4l^-WEn-fH39Jb9lA%s*WsKJQl?n9B7_~P z-XM&WL7Z!PcoF6_D>V@$CvUIEy=+Z&0kt{szMk=f1|M+r*a43^$$B^MidrT0J;RI` z(?f!O<8UZkm$_Ny$Hth1J#^4ni+im8M9mr&k|3cIgwvjAgjH z8`N&h25xV#v*d$qBX5jkI|xOhQn!>IYZK7l5#^P4M&twe9&Ey@@GxYMxBZq2e7?`q z$~Szs0!g{2fGcp9PZEt|rdQ6bhAgpcLHPz?f-vB?$dc*!9OL?Q8mn7->bFD2Si60* z!O%y)fCdMSV|lkF9w%x~J*A&srMyYY3{=&$}H zGQ4VG_?$2X(0|vT0{=;W$~icCI{b6W{B!Q8xdGhF|D{25G_5_+%s(46lhvNLkik~R z>nr(&C#5wwOzJZQo9m|U<;&Wk!_#q|V>fsmj1g<6%hB{jGoNUPjgJslld>xmODzGjYc?7JSuA?A_QzjDw5AsRgi@Y|Z0{F{!1=!NES-#*f^s4l0Hu zz468))2IY5dmD9pa*(yT5{EyP^G>@ZWumealS-*WeRcZ}B%gxq{MiJ|RyX-^C1V=0 z@iKdrGi1jTe8Ya^x7yyH$kBNvM4R~`fbPq$BzHum-3Zo8C6=KW@||>zsA8-Y9uV5V z#oq-f5L5}V<&wF4@X@<3^C%ptp6+Ce)~hGl`kwj)bsAjmo_GU^r940Z-|`<)oGnh7 zFF0Tde3>ui?8Yj{sF-Z@)yQd~CGZ*w-6p2U<8}JO-sRsVI5dBji`01W8A&3$?}lxBaC&vn0E$c5tW* zX>5(zzZ=qn&!J~KdsPl;P@bmA-Pr8T*)eh_+Dv5=Ma|XSle6t(k8qcgNyar{*ReQ8 zTXwi=8vr>!3Ywr+BhggHDw8ke==NTQVMCK`$69fhzEFB*4+H9LIvdt-#IbhZvpS}} zO3lz;P?zr0*0$%-Rq_y^k(?I{Mk}h@w}cZpMUp|ucs55bcloL2)($u%mXQw({Wzc~ z;6nu5MkjP)0C(@%6Q_I_vsWrfhl7Zpoxw#WoE~r&GOSCz;_ro6i(^hM>I$8y>`!wW z*U^@?B!MMmb89I}2(hcE4zN2G^kwyWCZp5JG>$Ez7zP~D=J^LMjSM)27_0B_X^C(M z`fFT+%DcKlu?^)FCK>QzSnV%IsXVcUFhFdBP!6~se&xxrIxsvySAWu++IrH;FbcY$ z2DWTvSBRfLwdhr0nMx+URA$j3i7_*6BWv#DXfym?ZRDcX9C?cY9sD3q)uBDR3uWg= z(lUIzB)G$Hr!){>E{s4Dew+tb9kvToZp-1&c?y2wn@Z~(VBhqz`cB;{E4(P3N2*nJ z_>~g@;UF2iG{Kt(<1PyePTKahF8<)pozZ*xH~U-kfoAayCwJViIrnqwqO}7{0pHw$ zs2Kx?s#vQr7XZ264>5RNKSL8|Ty^=PsIx^}QqOOcfpGUU4tRkUc|kc7-!Ae6!+B{o~7nFpm3|G5^=0#Bnm6`V}oSQlrX(u%OWnC zoLPy&Q;1Jui&7ST0~#+}I^&?vcE*t47~Xq#YwvA^6^} z`WkC)$AkNub|t@S!$8CBlwbV~?yp&@9h{D|3z-vJXgzRC5^nYm+PyPcgRzAnEi6Q^gslXYRv4nycsy-SJu?lMps-? zV`U*#WnFsdPLL)Q$AmD|0`UaC4ND07+&UmOu!eHruzV|OUox<+Jl|Mr@6~C`T@P%s zW7sgXLF2SSe9Fl^O(I*{9wsFSYb2l%-;&Pi^dpv!{)C3d0AlNY6!4fgmSgj_wQ*7Am7&$z;Jg&wgR-Ih;lUvWS|KTSg!&s_E9_bXBkZvGiC6bFKDWZxsD$*NZ#_8bl zG1P-#@?OQzED7@jlMJTH@V!6k;W>auvft)}g zhoV{7$q=*;=l{O>Q4a@ ziMjf_u*o^PsO)#BjC%0^h>Xp@;5$p{JSYDt)zbb}s{Kbt!T*I@Pk@X0zds6wsefuU zW$XY%yyRGC94=6mf?x+bbA5CDQ2AgW1T-jVAJbm7K(gp+;v6E0WI#kuACgV$r}6L? zd|Tj?^%^*N&b>Dd{Wr$FS2qI#Ucs1yd4N+RBUQiSZGujH`#I)mG&VKoDh=KKFl4=G z&MagXl6*<)$6P}*Tiebpz5L=oMaPrN+caUXRJ`D?=K9!e0f{@D&cZLKN?iNP@X0aF zE(^pl+;*T5qt?1jRC=5PMgV!XNITRLS_=9{CJExaQj;lt!&pdzpK?8p>%Mb+D z?yO*uSung=-`QQ@yX@Hyd4@CI^r{2oiu`%^bNkz+Nkk!IunjwNC|WcqvX~k=><-I3 zDQdbdb|!v+Iz01$w@aMl!R)koD77Xp;eZwzSl-AT zr@Vu{=xvgfq9akRrrM)}=!=xcs+U1JO}{t(avgz`6RqiiX<|hGG1pmop8k6Q+G_mv zJv|RfDheUp2L3=^C=4aCBMBn0aRCU(DQwX-W(RkRwmLeuJYF<0urcaf(=7)JPg<3P zQs!~G)9CT18o!J4{zX{_e}4eS)U-E)0FAt}wEI(c0%HkxgggW;(1E=>J17_hsH^sP z%lT0LGgbUXHx-K*CI-MCrP66UP0PvGqM$MkeLyqHdbgP|_Cm!7te~b8p+e6sQ_3k| zVcwTh6d83ltdnR>D^)BYQpDKlLk3g0Hdcgz2}%qUs9~~Rie)A-BV1mS&naYai#xcZ z(d{8=-LVpTp}2*y)|gR~;qc7fp26}lPcLZ#=JpYcn3AT9(UIdOyg+d(P5T7D&*P}# zQCYplZO5|7+r19%9e`v^vfSS1sbX1c%=w1;oyruXB%Kl$ACgKQ6=qNWLsc=28xJjg zwvsI5-%SGU|3p>&zXVl^vVtQT3o-#$UT9LI@Npz~6=4!>mc431VRNN8od&Ul^+G_kHC`G=6WVWM z%9eWNyy(FTO|A+@x}Ou3CH)oi;t#7rAxdIXfNFwOj_@Y&TGz6P_sqiB`Q6Lxy|Q{`|fgmRG(k+!#b*M+Z9zFce)f-7;?Km5O=LHV9f9_87; zF7%R2B+$?@sH&&-$@tzaPYkw0;=i|;vWdI|Wl3q_Zu>l;XdIw2FjV=;Mq5t1Q0|f< zs08j54Bp`3RzqE=2enlkZxmX6OF+@|2<)A^RNQpBd6o@OXl+i)zO%D4iGiQNuXd+zIR{_lb96{lc~bxsBveIw6umhShTX+3@ZJ=YHh@ zWY3(d0azg;7oHn>H<>?4@*RQbi>SmM=JrHvIG(~BrvI)#W(EAeO6fS+}mxxcc+X~W6&YVl86W9WFSS}Vz-f9vS?XUDBk)3TcF z8V?$4Q)`uKFq>xT=)Y9mMFVTUk*NIA!0$?RP6Ig0TBmUFrq*Q-Agq~DzxjStQyJ({ zBeZ;o5qUUKg=4Hypm|}>>L=XKsZ!F$yNTDO)jt4H0gdQ5$f|d&bnVCMMXhNh)~mN z@_UV6D7MVlsWz+zM+inZZp&P4fj=tm6fX)SG5H>OsQf_I8c~uGCig$GzuwViK54bcgL;VN|FnyQl>Ed7(@>=8$a_UKIz|V6CeVSd2(P z0Uu>A8A+muM%HLFJQ9UZ5c)BSAv_zH#1f02x?h9C}@pN@6{>UiAp>({Fn(T9Q8B z^`zB;kJ5b`>%dLm+Ol}ty!3;8f1XDSVX0AUe5P#@I+FQ-`$(a;zNgz)4x5hz$Hfbg z!Q(z26wHLXko(1`;(BAOg_wShpX0ixfWq3ponndY+u%1gyX)_h=v1zR#V}#q{au6; z!3K=7fQwnRfg6FXtNQmP>`<;!N137paFS%y?;lb1@BEdbvQHYC{976l`cLqn;b8lp zIDY>~m{gDj(wfnK!lpW6pli)HyLEiUrNc%eXTil|F2s(AY+LW5hkKb>TQ3|Q4S9rr zpDs4uK_co6XPsn_z$LeS{K4jFF`2>U`tbgKdyDne`xmR<@6AA+_hPNKCOR-Zqv;xk zu5!HsBUb^!4uJ7v0RuH-7?l?}b=w5lzzXJ~gZcxRKOovSk@|#V+MuX%Y+=;14i*%{)_gSW9(#4%)AV#3__kac1|qUy!uyP{>?U#5wYNq}y$S9pCc zFc~4mgSC*G~j0u#qqp9 z${>3HV~@->GqEhr_Xwoxq?Hjn#=s2;i~g^&Hn|aDKpA>Oc%HlW(KA1?BXqpxB;Ydx)w;2z^MpjJ(Qi(X!$5RC z*P{~%JGDQqojV>2JbEeCE*OEu!$XJ>bWA9Oa_Hd;y)F%MhBRi*LPcdqR8X`NQ&1L# z5#9L*@qxrx8n}LfeB^J{%-?SU{FCwiWyHp682F+|pa+CQa3ZLzBqN1{)h4d6+vBbV zC#NEbQLC;}me3eeYnOG*nXOJZEU$xLZ1<1Y=7r0(-U0P6-AqwMAM`a(Ed#7vJkn6plb4eI4?2y3yOTGmmDQ!z9`wzbf z_OY#0@5=bnep;MV0X_;;SJJWEf^E6Bd^tVJ9znWx&Ks8t*B>AM@?;D4oWUGc z!H*`6d7Cxo6VuyS4Eye&L1ZRhrRmN6Lr`{NL(wDbif|y&z)JN>Fl5#Wi&mMIr5i;x zBx}3YfF>>8EC(fYnmpu~)CYHuHCyr5*`ECap%t@y=jD>!_%3iiE|LN$mK9>- zHdtpy8fGZtkZF?%TW~29JIAfi2jZT8>OA7=h;8T{{k?c2`nCEx9$r zS+*&vt~2o^^J+}RDG@+9&M^K*z4p{5#IEVbz`1%`m5c2};aGt=V?~vIM}ZdPECDI)47|CWBCfDWUbxBCnmYivQ*0Nu_xb*C>~C9(VjHM zxe<*D<#dQ8TlpMX2c@M<9$w!RP$hpG4cs%AI){jp*Sj|*`m)5(Bw*A0$*i-(CA5#%>a)$+jI2C9r6|(>J8InryENI z$NohnxDUB;wAYDwrb*!N3noBTKPpPN}~09SEL18tkG zxgz(RYU_;DPT{l?Q$+eaZaxnsWCA^ds^0PVRkIM%bOd|G2IEBBiz{&^JtNsODs;5z zICt_Zj8wo^KT$7Bg4H+y!Df#3mbl%%?|EXe!&(Vmac1DJ*y~3+kRKAD=Ovde4^^%~ zw<9av18HLyrf*_>Slp;^i`Uy~`mvBjZ|?Ad63yQa#YK`4+c6;pW4?XIY9G1(Xh9WO8{F-Aju+nS9Vmv=$Ac0ienZ+p9*O%NG zMZKy5?%Z6TAJTE?o5vEr0r>f>hb#2w2U3DL64*au_@P!J!TL`oH2r*{>ffu6|A7tv zL4juf$DZ1MW5ZPsG!5)`k8d8c$J$o;%EIL0va9&GzWvkS%ZsGb#S(?{!UFOZ9<$a| zY|a+5kmD5N&{vRqkgY>aHsBT&`rg|&kezoD)gP0fsNYHsO#TRc_$n6Lf1Z{?+DLziXlHrq4sf(!>O{?Tj;Eh@%)+nRE_2VxbN&&%%caU#JDU%vL3}Cb zsb4AazPI{>8H&d=jUaZDS$-0^AxE@utGs;-Ez_F(qC9T=UZX=>ok2k2 ziTn{K?y~a5reD2A)P${NoI^>JXn>`IeArow(41c-Wm~)wiryEP(OS{YXWi7;%dG9v zI?mwu1MxD{yp_rrk!j^cKM)dc4@p4Ezyo%lRN|XyD}}>v=Xoib0gOcdXrQ^*61HNj z=NP|pd>@yfvr-=m{8$3A8TQGMTE7g=z!%yt`8`Bk-0MMwW~h^++;qyUP!J~ykh1GO z(FZ59xuFR$(WE;F@UUyE@Sp>`aVNjyj=Ty>_Vo}xf`e7`F;j-IgL5`1~-#70$9_=uBMq!2&1l zomRgpD58@)YYfvLtPW}{C5B35R;ZVvB<<#)x%srmc_S=A7F@DW8>QOEGwD6suhwCg z>Pa+YyULhmw%BA*4yjDp|2{!T98~<6Yfd(wo1mQ!KWwq0eg+6)o1>W~f~kL<-S+P@$wx*zeI|1t7z#Sxr5 zt6w+;YblPQNplq4Z#T$GLX#j6yldXAqj>4gAnnWtBICUnA&-dtnlh=t0Ho_vEKwV` z)DlJi#!@nkYV#$!)@>udAU*hF?V`2$Hf=V&6PP_|r#Iv*J$9)pF@X3`k;5})9^o4y z&)~?EjX5yX12O(BsFy-l6}nYeuKkiq`u9145&3Ssg^y{5G3Pse z9w(YVa0)N-fLaBq1`P!_#>SS(8fh_5!f{UrgZ~uEdeMJIz7DzI5!NHHqQtm~#CPij z?=N|J>nPR6_sL7!f4hD_|KH`vf8(Wpnj-(gPWH+ZvID}%?~68SwhPTC3u1_cB`otq z)U?6qo!ZLi5b>*KnYHWW=3F!p%h1;h{L&(Q&{qY6)_qxNfbP6E3yYpW!EO+IW3?@J z);4>g4gnl^8klu7uA>eGF6rIGSynacogr)KUwE_R4E5Xzi*Qir@b-jy55-JPC8c~( zo!W8y9OGZ&`xmc8;=4-U9=h{vCqfCNzYirONmGbRQlR`WWlgnY+1wCXbMz&NT~9*| z6@FrzP!LX&{no2!Ln_3|I==_4`@}V?4a;YZKTdw;vT<+K+z=uWbW(&bXEaWJ^W8Td z-3&1bY^Z*oM<=M}LVt>_j+p=2Iu7pZmbXrhQ_k)ysE9yXKygFNw$5hwDn(M>H+e1&9BM5!|81vd%r%vEm zqxY3?F@fb6O#5UunwgAHR9jp_W2zZ}NGp2%mTW@(hz7$^+a`A?mb8|_G*GNMJ) zjqegXQio=i@AINre&%ofexAr95aop5C+0MZ0m-l=MeO8m3epm7U%vZB8+I+C*iNFM z#T3l`gknX;D$-`2XT^Cg*vrv=RH+P;_dfF++cP?B_msQI4j+lt&rX2)3GaJx%W*Nn zkML%D{z5tpHH=dksQ*gzc|}gzW;lwAbxoR07VNgS*-c3d&8J|;@3t^ zVUz*J*&r7DFRuFVDCJDK8V9NN5hvpgGjwx+5n)qa;YCKe8TKtdnh{I7NU9BCN!0dq zczrBk8pE{{@vJa9ywR@mq*J=v+PG;?fwqlJVhijG!3VmIKs>9T6r7MJpC)m!Tc#>g zMtVsU>wbwFJEfwZ{vB|ZlttNe83)$iz`~#8UJ^r)lJ@HA&G#}W&ZH*;k{=TavpjWE z7hdyLZPf*X%Gm}i`Y{OGeeu^~nB8=`{r#TUrM-`;1cBvEd#d!kPqIgYySYhN-*1;L z^byj%Yi}Gx)Wnkosi337BKs}+5H5dth1JA{Ir-JKN$7zC)*}hqeoD(WfaUDPT>0`- z(6sa0AoIqASwF`>hP}^|)a_j2s^PQn*qVC{Q}htR z5-)duBFXT_V56-+UohKXlq~^6uf!6sA#ttk1o~*QEy_Y-S$gAvq47J9Vtk$5oA$Ct zYhYJ@8{hsC^98${!#Ho?4y5MCa7iGnfz}b9jE~h%EAAv~Qxu)_rAV;^cygV~5r_~?l=B`zObj7S=H=~$W zPtI_m%g$`kL_fVUk9J@>EiBH zOO&jtn~&`hIFMS5S`g8w94R4H40mdNUH4W@@XQk1sr17b{@y|JB*G9z1|CrQjd+GX z6+KyURG3;!*BQrentw{B2R&@2&`2}n(z-2&X7#r!{yg@Soy}cRD~j zj9@UBW+N|4HW4AWapy4wfUI- zZ`gSL6DUlgj*f1hSOGXG0IVH8HxK?o2|3HZ;KW{K+yPAlxtb)NV_2AwJm|E)FRs&& z=c^e7bvUsztY|+f^k7NXs$o1EUq>cR7C0$UKi6IooHWlK_#?IWDkvywnzg&ThWo^? z2O_N{5X39#?eV9l)xI(>@!vSB{DLt*oY!K1R8}_?%+0^C{d9a%N4 zoxHVT1&Lm|uDX%$QrBun5e-F`HJ^T$ zmzv)p@4ZHd_w9!%Hf9UYNvGCw2TTTbrj9pl+T9%-_-}L(tES>Or-}Z4F*{##n3~L~TuxjirGuIY#H7{%$E${?p{Q01 zi6T`n;rbK1yIB9jmQNycD~yZq&mbIsFWHo|ZAChSFPQa<(%d8mGw*V3fh|yFoxOOiWJd(qvVb!Z$b88cg->N=qO*4k~6;R==|9ihg&riu#P~s4Oap9O7f%crSr^rljeIfXDEg>wi)&v*a%7zpz<9w z*r!3q9J|390x`Zk;g$&OeN&ctp)VKRpDSV@kU2Q>jtok($Y-*x8_$2piTxun81@vt z!Vj?COa0fg2RPXMSIo26T=~0d`{oGP*eV+$!0I<(4azk&Vj3SiG=Q!6mX0p$z7I}; z9BJUFgT-K9MQQ-0@Z=^7R<{bn2Fm48endsSs`V7_@%8?Bxkqv>BDoVcj?K#dV#uUP zL1ND~?D-|VGKe3Rw_7-Idpht>H6XRLh*U7epS6byiGvJpr%d}XwfusjH9g;Z98H`x zyde%%5mhGOiL4wljCaWCk-&uE4_OOccb9c!ZaWt4B(wYl!?vyzl%7n~QepN&eFUrw zFIOl9c({``6~QD+43*_tzP{f2x41h(?b43^y6=iwyB)2os5hBE!@YUS5?N_tXd=h( z)WE286Fbd>R4M^P{!G)f;h<3Q>Fipuy+d2q-)!RyTgt;wr$(?9ox3;q+{E*ZQHhOn;lM`cjnu9 zXa48ks-v(~b*;MAI<>YZH(^NV8vjb34beE<_cwKlJoR;k6lJNSP6v}uiyRD?|0w+X@o1ONrH8a$fCxXpf? z?$DL0)7|X}Oc%h^zrMKWc-NS9I0Utu@>*j}b@tJ=ixQSJ={4@854wzW@E>VSL+Y{i z#0b=WpbCZS>kUCO_iQz)LoE>P5LIG-hv9E+oG}DtlIDF>$tJ1aw9^LuhLEHt?BCj& z(O4I8v1s#HUi5A>nIS-JK{v!7dJx)^Yg%XjNmlkWAq2*cv#tHgz`Y(bETc6CuO1VkN^L-L3j_x<4NqYb5rzrLC-7uOv z!5e`GZt%B782C5-fGnn*GhDF$%(qP<74Z}3xx+{$4cYKy2ikxI7B2N+2r07DN;|-T->nU&!=Cm#rZt%O_5c&1Z%nlWq3TKAW0w zQqemZw_ue--2uKQsx+niCUou?HjD`xhEjjQd3%rrBi82crq*~#uA4+>vR<_S{~5ce z-2EIl?~s z1=GVL{NxP1N3%=AOaC}j_Fv=ur&THz zyO!d9kHq|c73kpq`$+t+8Bw7MgeR5~`d7ChYyGCBWSteTB>8WAU(NPYt2Dk`@#+}= zI4SvLlyk#pBgVigEe`?NG*vl7V6m+<}%FwPV=~PvvA)=#ths==DRTDEYh4V5}Cf$z@#;< zyWfLY_5sP$gc3LLl2x+Ii)#b2nhNXJ{R~vk`s5U7Nyu^3yFg&D%Txwj6QezMX`V(x z=C`{76*mNb!qHHs)#GgGZ_7|vkt9izl_&PBrsu@}L`X{95-2jf99K)0=*N)VxBX2q z((vkpP2RneSIiIUEnGb?VqbMb=Zia+rF~+iqslydE34cSLJ&BJW^3knX@M;t*b=EA zNvGzv41Ld_T+WT#XjDB840vovUU^FtN_)G}7v)1lPetgpEK9YS^OWFkPoE{ovj^=@ zO9N$S=G$1ecndT_=5ehth2Lmd1II-PuT~C9`XVePw$y8J#dpZ?Tss<6wtVglm(Ok7 z3?^oi@pPio6l&!z8JY(pJvG=*pI?GIOu}e^EB6QYk$#FJQ%^AIK$I4epJ+9t?KjqA+bkj&PQ*|vLttme+`9G=L% ziadyMw_7-M)hS(3E$QGNCu|o23|%O+VN7;Qggp?PB3K-iSeBa2b}V4_wY`G1Jsfz4 z9|SdB^;|I8E8gWqHKx!vj_@SMY^hLEIbSMCuE?WKq=c2mJK z8LoG-pnY!uhqFv&L?yEuxo{dpMTsmCn)95xanqBrNPTgXP((H$9N${Ow~Is-FBg%h z53;|Y5$MUN)9W2HBe2TD`ct^LHI<(xWrw}$qSoei?}s)&w$;&!14w6B6>Yr6Y8b)S z0r71`WmAvJJ`1h&poLftLUS6Ir zC$bG9!Im_4Zjse)#K=oJM9mHW1{%l8sz$1o?ltdKlLTxWWPB>Vk22czVt|1%^wnN@*!l)}?EgtvhC>vlHm^t+ogpgHI1_$1ox9e;>0!+b(tBrmXRB`PY1vp-R**8N7 zGP|QqI$m(Rdu#=(?!(N}G9QhQ%o!aXE=aN{&wtGP8|_qh+7a_j_sU5|J^)vxq;# zjvzLn%_QPHZZIWu1&mRAj;Sa_97p_lLq_{~j!M9N^1yp3U_SxRqK&JnR%6VI#^E12 z>CdOVI^_9aPK2eZ4h&^{pQs}xsijXgFYRIxJ~N7&BB9jUR1fm!(xl)mvy|3e6-B3j zJn#ajL;bFTYJ2+Q)tDjx=3IklO@Q+FFM}6UJr6km7hj7th9n_&JR7fnqC!hTZoM~T zBeaVFp%)0cbPhejX<8pf5HyRUj2>aXnXBqDJe73~J%P(2C?-RT{c3NjE`)om! zl$uewSgWkE66$Kb34+QZZvRn`fob~Cl9=cRk@Es}KQm=?E~CE%spXaMO6YmrMl%9Q zlA3Q$3|L1QJ4?->UjT&CBd!~ru{Ih^in&JXO=|<6J!&qp zRe*OZ*cj5bHYlz!!~iEKcuE|;U4vN1rk$xq6>bUWD*u(V@8sG^7>kVuo(QL@Ki;yL zWC!FT(q{E8#on>%1iAS0HMZDJg{Z{^!De(vSIq&;1$+b)oRMwA3nc3mdTSG#3uYO_ z>+x;7p4I;uHz?ZB>dA-BKl+t-3IB!jBRgdvAbW!aJ(Q{aT>+iz?91`C-xbe)IBoND z9_Xth{6?(y3rddwY$GD65IT#f3<(0o#`di{sh2gm{dw*#-Vnc3r=4==&PU^hCv$qd zjw;>i&?L*Wq#TxG$mFIUf>eK+170KG;~+o&1;Tom9}}mKo23KwdEM6UonXgc z!6N(@k8q@HPw{O8O!lAyi{rZv|DpgfU{py+j(X_cwpKqcalcqKIr0kM^%Br3SdeD> zHSKV94Yxw;pjzDHo!Q?8^0bb%L|wC;4U^9I#pd5O&eexX+Im{ z?jKnCcsE|H?{uGMqVie_C~w7GX)kYGWAg%-?8|N_1#W-|4F)3YTDC+QSq1s!DnOML3@d`mG%o2YbYd#jww|jD$gotpa)kntakp#K;+yo-_ZF9qrNZw<%#C zuPE@#3RocLgPyiBZ+R_-FJ_$xP!RzWm|aN)S+{$LY9vvN+IW~Kf3TsEIvP+B9Mtm! zpfNNxObWQpLoaO&cJh5>%slZnHl_Q~(-Tfh!DMz(dTWld@LG1VRF`9`DYKhyNv z2pU|UZ$#_yUx_B_|MxUq^glT}O5Xt(Vm4Mr02><%C)@v;vPb@pT$*yzJ4aPc_FZ3z z3}PLoMBIM>q_9U2rl^sGhk1VUJ89=*?7|v`{!Z{6bqFMq(mYiA?%KbsI~JwuqVA9$H5vDE+VocjX+G^%bieqx->s;XWlKcuv(s%y%D5Xbc9+ zc(_2nYS1&^yL*ey664&4`IoOeDIig}y-E~_GS?m;D!xv5-xwz+G`5l6V+}CpeJDi^ z%4ed$qowm88=iYG+(`ld5Uh&>Dgs4uPHSJ^TngXP_V6fPyl~>2bhi20QB%lSd#yYn zO05?KT1z@?^-bqO8Cg`;ft>ilejsw@2%RR7;`$Vs;FmO(Yr3Fp`pHGr@P2hC%QcA|X&N2Dn zYf`MqXdHi%cGR@%y7Rg7?d3?an){s$zA{!H;Ie5exE#c~@NhQUFG8V=SQh%UxUeiV zd7#UcYqD=lk-}sEwlpu&H^T_V0{#G?lZMxL7ih_&{(g)MWBnCZxtXg znr#}>U^6!jA%e}@Gj49LWG@*&t0V>Cxc3?oO7LSG%~)Y5}f7vqUUnQ;STjdDU}P9IF9d9<$;=QaXc zL1^X7>fa^jHBu_}9}J~#-oz3Oq^JmGR#?GO7b9a(=R@fw@}Q{{@`Wy1vIQ#Bw?>@X z-_RGG@wt|%u`XUc%W{J z>iSeiz8C3H7@St3mOr_mU+&bL#Uif;+Xw-aZdNYUpdf>Rvu0i0t6k*}vwU`XNO2he z%miH|1tQ8~ZK!zmL&wa3E;l?!!XzgV#%PMVU!0xrDsNNZUWKlbiOjzH-1Uoxm8E#r`#2Sz;-o&qcqB zC-O_R{QGuynW14@)7&@yw1U}uP(1cov)twxeLus0s|7ayrtT8c#`&2~Fiu2=R;1_4bCaD=*E@cYI>7YSnt)nQc zohw5CsK%m?8Ack)qNx`W0_v$5S}nO|(V|RZKBD+btO?JXe|~^Qqur%@eO~<8-L^9d z=GA3-V14ng9L29~XJ>a5k~xT2152zLhM*@zlp2P5Eu}bywkcqR;ISbas&#T#;HZSf z2m69qTV(V@EkY(1Dk3`}j)JMo%ZVJ*5eB zYOjIisi+igK0#yW*gBGj?@I{~mUOvRFQR^pJbEbzFxTubnrw(Muk%}jI+vXmJ;{Q6 zrSobKD>T%}jV4Ub?L1+MGOD~0Ir%-`iTnWZN^~YPrcP5y3VMAzQ+&en^VzKEb$K!Q z<7Dbg&DNXuow*eD5yMr+#08nF!;%4vGrJI++5HdCFcGLfMW!KS*Oi@=7hFwDG!h2< zPunUEAF+HncQkbfFj&pbzp|MU*~60Z(|Ik%Tn{BXMN!hZOosNIseT?R;A`W?=d?5X zK(FB=9mZusYahp|K-wyb={rOpdn=@;4YI2W0EcbMKyo~-#^?h`BA9~o285%oY zfifCh5Lk$SY@|2A@a!T2V+{^!psQkx4?x0HSV`(w9{l75QxMk!)U52Lbhn{8ol?S) zCKo*7R(z!uk<6*qO=wh!Pul{(qq6g6xW;X68GI_CXp`XwO zxuSgPRAtM8K7}5E#-GM!*ydOOG_{A{)hkCII<|2=ma*71ci_-}VPARm3crFQjLYV! z9zbz82$|l01mv`$WahE2$=fAGWkd^X2kY(J7iz}WGS z@%MyBEO=A?HB9=^?nX`@nh;7;laAjs+fbo!|K^mE!tOB>$2a_O0y-*uaIn8k^6Y zSbuv;5~##*4Y~+y7Z5O*3w4qgI5V^17u*ZeupVGH^nM&$qmAk|anf*>r zWc5CV;-JY-Z@Uq1Irpb^O`L_7AGiqd*YpGUShb==os$uN3yYvb`wm6d=?T*it&pDk zo`vhw)RZX|91^^Wa_ti2zBFyWy4cJu#g)_S6~jT}CC{DJ_kKpT`$oAL%b^!2M;JgT zM3ZNbUB?}kP(*YYvXDIH8^7LUxz5oE%kMhF!rnPqv!GiY0o}NR$OD=ITDo9r%4E>E0Y^R(rS^~XjWyVI6 zMOR5rPXhTp*G*M&X#NTL`Hu*R+u*QNoiOKg4CtNPrjgH>c?Hi4MUG#I917fx**+pJfOo!zFM&*da&G_x)L(`k&TPI*t3e^{crd zX<4I$5nBQ8Ax_lmNRa~E*zS-R0sxkz`|>7q_?*e%7bxqNm3_eRG#1ae3gtV9!fQpY z+!^a38o4ZGy9!J5sylDxZTx$JmG!wg7;>&5H1)>f4dXj;B+@6tMlL=)cLl={jLMxY zbbf1ax3S4>bwB9-$;SN2?+GULu;UA-35;VY*^9Blx)Jwyb$=U!D>HhB&=jSsd^6yw zL)?a|>GxU!W}ocTC(?-%z3!IUhw^uzc`Vz_g>-tv)(XA#JK^)ZnC|l1`@CdX1@|!| z_9gQ)7uOf?cR@KDp97*>6X|;t@Y`k_N@)aH7gY27)COv^P3ya9I{4z~vUjLR9~z1Z z5=G{mVtKH*&$*t0@}-i_v|3B$AHHYale7>E+jP`ClqG%L{u;*ff_h@)al?RuL7tOO z->;I}>%WI{;vbLP3VIQ^iA$4wl6@0sDj|~112Y4OFjMs`13!$JGkp%b&E8QzJw_L5 zOnw9joc0^;O%OpF$Qp)W1HI!$4BaXX84`%@#^dk^hFp^pQ@rx4g(8Xjy#!X%+X5Jd@fs3amGT`}mhq#L97R>OwT5-m|h#yT_-v@(k$q7P*9X~T*3)LTdzP!*B} z+SldbVWrrwQo9wX*%FyK+sRXTa@O?WM^FGWOE?S`R(0P{<6p#f?0NJvnBia?k^fX2 zNQs7K-?EijgHJY}&zsr;qJ<*PCZUd*x|dD=IQPUK_nn)@X4KWtqoJNHkT?ZWL_hF? zS8lp2(q>;RXR|F;1O}EE#}gCrY~#n^O`_I&?&z5~7N;zL0)3Tup`%)oHMK-^r$NT% zbFg|o?b9w(q@)6w5V%si<$!U<#}s#x@0aX-hP>zwS#9*75VXA4K*%gUc>+yzupTDBOKH8WR4V0pM(HrfbQ&eJ79>HdCvE=F z|J>s;;iDLB^3(9}?biKbxf1$lI!*Z%*0&8UUq}wMyPs_hclyQQi4;NUY+x2qy|0J; zhn8;5)4ED1oHwg+VZF|80<4MrL97tGGXc5Sw$wAI#|2*cvQ=jB5+{AjMiDHmhUC*a zlmiZ`LAuAn_}hftXh;`Kq0zblDk8?O-`tnilIh|;3lZp@F_osJUV9`*R29M?7H{Fy z`nfVEIDIWXmU&YW;NjU8)EJpXhxe5t+scf|VXM!^bBlwNh)~7|3?fWwo_~ZFk(22% zTMesYw+LNx3J-_|DM~`v93yXe=jPD{q;li;5PD?Dyk+b? zo21|XpT@)$BM$%F=P9J19Vi&1#{jM3!^Y&fr&_`toi`XB1!n>sbL%U9I5<7!@?t)~ z;&H%z>bAaQ4f$wIzkjH70;<8tpUoxzKrPhn#IQfS%9l5=Iu))^XC<58D!-O z{B+o5R^Z21H0T9JQ5gNJnqh#qH^na|z92=hONIM~@_iuOi|F>jBh-?aA20}Qx~EpDGElELNn~|7WRXRFnw+Wdo`|# zBpU=Cz3z%cUJ0mx_1($X<40XEIYz(`noWeO+x#yb_pwj6)R(__%@_Cf>txOQ74wSJ z0#F3(zWWaR-jMEY$7C*3HJrohc79>MCUu26mfYN)f4M~4gD`}EX4e}A!U}QV8!S47 z6y-U-%+h`1n`*pQuKE%Av0@)+wBZr9mH}@vH@i{v(m-6QK7Ncf17x_D=)32`FOjjo zg|^VPf5c6-!FxN{25dvVh#fog=NNpXz zfB$o+0jbRkHH{!TKhE709f+jI^$3#v1Nmf80w`@7-5$1Iv_`)W^px8P-({xwb;D0y z7LKDAHgX<84?l!I*Dvi2#D@oAE^J|g$3!)x1Ua;_;<@#l1fD}lqU2_tS^6Ht$1Wl} zBESo7o^)9-Tjuz$8YQSGhfs{BQV6zW7dA?0b(Dbt=UnQs&4zHfe_sj{RJ4uS-vQpC zX;Bbsuju4%!o8?&m4UZU@~ZZjeFF6ex2ss5_60_JS_|iNc+R0GIjH1@Z z=rLT9%B|WWgOrR7IiIwr2=T;Ne?30M!@{%Qf8o`!>=s<2CBpCK_TWc(DX51>e^xh8 z&@$^b6CgOd7KXQV&Y4%}_#uN*mbanXq(2=Nj`L7H7*k(6F8s6{FOw@(DzU`4-*77{ zF+dxpv}%mFpYK?>N_2*#Y?oB*qEKB}VoQ@bzm>ptmVS_EC(#}Lxxx730trt0G)#$b zE=wVvtqOct1%*9}U{q<)2?{+0TzZzP0jgf9*)arV)*e!f`|jgT{7_9iS@e)recI#z zbzolURQ+TOzE!ymqvBY7+5NnAbWxvMLsLTwEbFqW=CPyCsmJ}P1^V30|D5E|p3BC5 z)3|qgw@ra7aXb-wsa|l^in~1_fm{7bS9jhVRkYVO#U{qMp z)Wce+|DJ}4<2gp8r0_xfZpMo#{Hl2MfjLcZdRB9(B(A(f;+4s*FxV{1F|4d`*sRNd zp4#@sEY|?^FIJ;tmH{@keZ$P(sLh5IdOk@k^0uB^BWr@pk6mHy$qf&~rI>P*a;h0C{%oA*i!VjWn&D~O#MxN&f@1Po# zKN+ zrGrkSjcr?^R#nGl<#Q722^wbYcgW@{+6CBS<1@%dPA8HC!~a`jTz<`g_l5N1M@9wn9GOAZ>nqNgq!yOCbZ@1z`U_N`Z>}+1HIZxk*5RDc&rd5{3qjRh8QmT$VyS;jK z;AF+r6XnnCp=wQYoG|rT2@8&IvKq*IB_WvS%nt%e{MCFm`&W*#LXc|HrD?nVBo=(8*=Aq?u$sDA_sC_RPDUiQ+wnIJET8vx$&fxkW~kP9qXKt zozR)@xGC!P)CTkjeWvXW5&@2?)qt)jiYWWBU?AUtzAN}{JE1I)dfz~7$;}~BmQF`k zpn11qmObXwRB8&rnEG*#4Xax3XBkKlw(;tb?Np^i+H8m(Wyz9k{~ogba@laiEk;2! zV*QV^6g6(QG%vX5Um#^sT&_e`B1pBW5yVth~xUs#0}nv?~C#l?W+9Lsb_5)!71rirGvY zTIJ$OPOY516Y|_014sNv+Z8cc5t_V=i>lWV=vNu#!58y9Zl&GsMEW#pPYPYGHQ|;vFvd*9eM==$_=vc7xnyz0~ zY}r??$<`wAO?JQk@?RGvkWVJlq2dk9vB(yV^vm{=NVI8dhsX<)O(#nr9YD?I?(VmQ z^r7VfUBn<~p3()8yOBjm$#KWx!5hRW)5Jl7wY@ky9lNM^jaT##8QGVsYeaVywmpv>X|Xj7gWE1Ezai&wVLt3p)k4w~yrskT-!PR!kiyQlaxl(( zXhF%Q9x}1TMt3~u@|#wWm-Vq?ZerK={8@~&@9r5JW}r#45#rWii};t`{5#&3$W)|@ zbAf2yDNe0q}NEUvq_Quq3cTjcw z@H_;$hu&xllCI9CFDLuScEMg|x{S7GdV8<&Mq=ezDnRZAyX-8gv97YTm0bg=d)(>N z+B2FcqvI9>jGtnK%eO%y zoBPkJTk%y`8TLf4)IXPBn`U|9>O~WL2C~C$z~9|0m*YH<-vg2CD^SX#&)B4ngOSG$ zV^wmy_iQk>dfN@Pv(ckfy&#ak@MLC7&Q6Ro#!ezM*VEh`+b3Jt%m(^T&p&WJ2Oqvj zs-4nq0TW6cv~(YI$n0UkfwN}kg3_fp?(ijSV#tR9L0}l2qjc7W?i*q01=St0eZ=4h zyGQbEw`9OEH>NMuIe)hVwYHsGERWOD;JxEiO7cQv%pFCeR+IyhwQ|y@&^24k+|8fD zLiOWFNJ2&vu2&`Jv96_z-Cd5RLgmeY3*4rDOQo?Jm`;I_(+ejsPM03!ly!*Cu}Cco zrQSrEDHNyzT(D5s1rZq!8#?f6@v6dB7a-aWs(Qk>N?UGAo{gytlh$%_IhyL7h?DLXDGx zgxGEBQoCAWo-$LRvM=F5MTle`M})t3vVv;2j0HZY&G z22^iGhV@uaJh(XyyY%} zd4iH_UfdV#T=3n}(Lj^|n;O4|$;xhu*8T3hR1mc_A}fK}jfZ7LX~*n5+`8N2q#rI$ z@<_2VANlYF$vIH$ zl<)+*tIWW78IIINA7Rr7i{<;#^yzxoLNkXL)eSs=%|P>$YQIh+ea_3k z_s7r4%j7%&*NHSl?R4k%1>Z=M9o#zxY!n8sL5>BO-ZP;T3Gut>iLS@U%IBrX6BA3k z)&@q}V8a{X<5B}K5s(c(LQ=%v1ocr`t$EqqY0EqVjr65usa=0bkf|O#ky{j3)WBR(((L^wmyHRzoWuL2~WTC=`yZ zn%VX`L=|Ok0v7?s>IHg?yArBcync5rG#^+u)>a%qjES%dRZoIyA8gQ;StH z1Ao7{<&}6U=5}4v<)1T7t!J_CL%U}CKNs-0xWoTTeqj{5{?Be$L0_tk>M9o8 zo371}S#30rKZFM{`H_(L`EM9DGp+Mifk&IP|C2Zu_)Ghr4Qtpmkm1osCf@%Z$%t+7 zYH$Cr)Ro@3-QDeQJ8m+x6%;?YYT;k6Z0E-?kr>x33`H%*ueBD7Zx~3&HtWn0?2Wt} zTG}*|v?{$ajzt}xPzV%lL1t-URi8*Zn)YljXNGDb>;!905Td|mpa@mHjIH%VIiGx- zd@MqhpYFu4_?y5N4xiHn3vX&|e6r~Xt> zZG`aGq|yTNjv;9E+Txuoa@A(9V7g?1_T5FzRI;!=NP1Kqou1z5?%X~Wwb{trRfd>i z8&y^H)8YnKyA_Fyx>}RNmQIczT?w2J4SNvI{5J&}Wto|8FR(W;Qw#b1G<1%#tmYzQ zQ2mZA-PAdi%RQOhkHy9Ea#TPSw?WxwL@H@cbkZwIq0B!@ns}niALidmn&W?!Vd4Gj zO7FiuV4*6Mr^2xlFSvM;Cp_#r8UaqIzHJQg_z^rEJw&OMm_8NGAY2)rKvki|o1bH~ z$2IbfVeY2L(^*rMRU1lM5Y_sgrDS`Z??nR2lX;zyR=c%UyGb*%TC-Dil?SihkjrQy~TMv6;BMs7P8il`H7DmpVm@rJ;b)hW)BL)GjS154b*xq-NXq2cwE z^;VP7ua2pxvCmxrnqUYQMH%a%nHmwmI33nJM(>4LznvY*k&C0{8f*%?zggpDgkuz&JBx{9mfb@wegEl2v!=}Sq2Gaty0<)UrOT0{MZtZ~j5y&w zXlYa_jY)I_+VA-^#mEox#+G>UgvM!Ac8zI<%JRXM_73Q!#i3O|)lOP*qBeJG#BST0 zqohi)O!|$|2SeJQo(w6w7%*92S})XfnhrH_Z8qe!G5>CglP=nI7JAOW?(Z29;pXJ9 zR9`KzQ=WEhy*)WH>$;7Cdz|>*i>=##0bB)oU0OR>>N<21e4rMCHDemNi2LD>Nc$;& zQRFthpWniC1J6@Zh~iJCoLOxN`oCKD5Q4r%ynwgUKPlIEd#?QViIqovY|czyK8>6B zSP%{2-<;%;1`#0mG^B(8KbtXF;Nf>K#Di72UWE4gQ%(_26Koiad)q$xRL~?pN71ZZ zujaaCx~jXjygw;rI!WB=xrOJO6HJ!!w}7eiivtCg5K|F6$EXa)=xUC za^JXSX98W`7g-tm@uo|BKj39Dl;sg5ta;4qjo^pCh~{-HdLl6qI9Ix6f$+qiZ$}s= zNguKrU;u+T@ko(Vr1>)Q%h$?UKXCY>3se%&;h2osl2D zE4A9bd7_|^njDd)6cI*FupHpE3){4NQ*$k*cOWZ_?CZ>Z4_fl@n(mMnYK62Q1d@+I zr&O))G4hMihgBqRIAJkLdk(p(D~X{-oBUA+If@B}j& zsHbeJ3RzTq96lB7d($h$xTeZ^gP0c{t!Y0c)aQE;$FY2!mACg!GDEMKXFOPI^)nHZ z`aSPJpvV0|bbrzhWWkuPURlDeN%VT8tndV8?d)eN*i4I@u zVKl^6{?}A?P)Fsy?3oi#clf}L18t;TjNI2>eI&(ezDK7RyqFxcv%>?oxUlonv(px) z$vnPzRH`y5A(x!yOIfL0bmgeMQB$H5wenx~!ujQK*nUBW;@Em&6Xv2%s(~H5WcU2R z;%Nw<$tI)a`Ve!>x+qegJnQsN2N7HaKzrFqM>`6R*gvh%O*-%THt zrB$Nk;lE;z{s{r^PPm5qz(&lM{sO*g+W{sK+m3M_z=4=&CC>T`{X}1Vg2PEfSj2x_ zmT*(x;ov%3F?qoEeeM>dUn$a*?SIGyO8m806J1W1o+4HRhc2`9$s6hM#qAm zChQ87b~GEw{ADfs+5}FJ8+|bIlIv(jT$Ap#hSHoXdd9#w<#cA<1Rkq^*EEkknUd4& zoIWIY)sAswy6fSERVm&!SO~#iN$OgOX*{9@_BWFyJTvC%S++ilSfCrO(?u=Dc?CXZ zzCG&0yVR{Z`|ZF0eEApWEo#s9osV>F{uK{QA@BES#&;#KsScf>y zvs?vIbI>VrT<*!;XmQS=bhq%46-aambZ(8KU-wOO2=en~D}MCToB_u;Yz{)1ySrPZ z@=$}EvjTdzTWU7c0ZI6L8=yP+YRD_eMMos}b5vY^S*~VZysrkq<`cK3>>v%uy7jgq z0ilW9KjVDHLv0b<1K_`1IkbTOINs0=m-22c%M~l=^S}%hbli-3?BnNq?b`hx^HX2J zIe6ECljRL0uBWb`%{EA=%!i^4sMcj+U_TaTZRb+~GOk z^ZW!nky0n*Wb*r+Q|9H@ml@Z5gU&W`(z4-j!OzC1wOke`TRAYGZVl$PmQ16{3196( zO*?`--I}Qf(2HIwb2&1FB^!faPA2=sLg(@6P4mN)>Dc3i(B0;@O-y2;lM4akD>@^v z=u>*|!s&9zem70g7zfw9FXl1bpJW(C#5w#uy5!V?Q(U35A~$dR%LDVnq@}kQm13{} zd53q3N(s$Eu{R}k2esbftfjfOITCL;jWa$}(mmm}d(&7JZ6d3%IABCapFFYjdEjdK z&4Edqf$G^MNAtL=uCDRs&Fu@FXRgX{*0<(@c3|PNHa>L%zvxWS={L8%qw`STm+=Rd zA}FLspESSIpE_^41~#5yI2bJ=9`oc;GIL!JuW&7YetZ?0H}$$%8rW@*J37L-~Rsx!)8($nI4 zZhcZ2^=Y+p4YPl%j!nFJA|*M^gc(0o$i3nlphe+~-_m}jVkRN{spFs(o0ajW@f3K{ zDV!#BwL322CET$}Y}^0ixYj2w>&Xh12|R8&yEw|wLDvF!lZ#dOTHM9pK6@Nm-@9Lnng4ZHBgBSrr7KI8YCC9DX5Kg|`HsiwJHg2(7#nS;A{b3tVO?Z% za{m5b3rFV6EpX;=;n#wltDv1LE*|g5pQ+OY&*6qCJZc5oDS6Z6JD#6F)bWxZSF@q% z+1WV;m!lRB!n^PC>RgQCI#D1br_o^#iPk>;K2hB~0^<~)?p}LG%kigm@moD#q3PE+ zA^Qca)(xnqw6x>XFhV6ku9r$E>bWNrVH9fum0?4s?Rn2LG{Vm_+QJHse6xa%nzQ?k zKug4PW~#Gtb;#5+9!QBgyB@q=sk9=$S{4T>wjFICStOM?__fr+Kei1 z3j~xPqW;W@YkiUM;HngG!;>@AITg}vAE`M2Pj9Irl4w1fo4w<|Bu!%rh%a(Ai^Zhi zs92>v5;@Y(Zi#RI*ua*h`d_7;byQSa*v9E{2x$<-_=5Z<7{%)}4XExANcz@rK69T0x3%H<@frW>RA8^swA+^a(FxK| zFl3LD*ImHN=XDUkrRhp6RY5$rQ{bRgSO*(vEHYV)3Mo6Jy3puiLmU&g82p{qr0F?ohmbz)f2r{X2|T2 z$4fdQ=>0BeKbiVM!e-lIIs8wVTuC_m7}y4A_%ikI;Wm5$9j(^Y z(cD%U%k)X>_>9~t8;pGzL6L-fmQO@K; zo&vQzMlgY95;1BSkngY)e{`n0!NfVgf}2mB3t}D9@*N;FQ{HZ3Pb%BK6;5#-O|WI( zb6h@qTLU~AbVW#_6?c!?Dj65Now7*pU{h!1+eCV^KCuPAGs28~3k@ueL5+u|Z-7}t z9|lskE`4B7W8wMs@xJa{#bsCGDFoRSNSnmNYB&U7 zVGKWe%+kFB6kb)e;TyHfqtU6~fRg)f|>=5(N36)0+C z`hv65J<$B}WUc!wFAb^QtY31yNleq4dzmG`1wHTj=c*=hay9iD071Hc?oYoUk|M*_ zU1GihAMBsM@5rUJ(qS?9ZYJ6@{bNqJ`2Mr+5#hKf?doa?F|+^IR!8lq9)wS3tF_9n zW_?hm)G(M+MYb?V9YoX^_mu5h-LP^TL^!Q9Z7|@sO(rg_4+@=PdI)WL(B7`!K^ND- z-uIuVDCVEdH_C@c71YGYT^_Scf_dhB8Z2Xy6vGtBSlYud9vggOqv^L~F{BraSE_t} zIkP+Hp2&nH^-MNEs}^`oMLy11`PQW$T|K(`Bu*(f@)mv1-qY(_YG&J2M2<7k;;RK~ zL{Fqj9yCz8(S{}@c)S!65aF<=&eLI{hAMErCx&>i7OeDN>okvegO87OaG{Jmi<|}D zaT@b|0X{d@OIJ7zvT>r+eTzgLq~|Dpu)Z&db-P4z*`M$UL51lf>FLlq6rfG)%doyp z)3kk_YIM!03eQ8Vu_2fg{+osaEJPtJ-s36R+5_AEG12`NG)IQ#TF9c@$99%0iye+ zUzZ57=m2)$D(5Nx!n)=5Au&O0BBgwxIBaeI(mro$#&UGCr<;C{UjJVAbVi%|+WP(a zL$U@TYCxJ=1{Z~}rnW;7UVb7+ZnzgmrogDxhjLGo>c~MiJAWs&&;AGg@%U?Y^0JhL ze(x6Z74JG6FlOFK(T}SXQfhr}RIFl@QXKnIcXYF)5|V~e-}suHILKT-k|<*~Ij|VF zC;t@=uj=hot~*!C68G8hTA%8SzOfETOXQ|3FSaIEjvBJp(A)7SWUi5!Eu#yWgY+;n zlm<$+UDou*V+246_o#V4kMdto8hF%%Lki#zPh}KYXmMf?hrN0;>Mv%`@{0Qn`Ujp) z=lZe+13>^Q!9zT);H<(#bIeRWz%#*}sgUX9P|9($kexOyKIOc`dLux}c$7It4u|Rl z6SSkY*V~g_B-hMPo_ak>>z@AVQ(_N)VY2kB3IZ0G(iDUYw+2d7W^~(Jq}KY=JnWS( z#rzEa&0uNhJ>QE8iiyz;n2H|SV#Og+wEZv=f2%1ELX!SX-(d3tEj$5$1}70Mp<&eI zCkfbByL7af=qQE@5vDVxx1}FSGt_a1DoE3SDI+G)mBAna)KBG4p8Epxl9QZ4BfdAN zFnF|Y(umr;gRgG6NLQ$?ZWgllEeeq~z^ZS7L?<(~O&$5|y)Al^iMKy}&W+eMm1W z7EMU)u^ke(A1#XCV>CZ71}P}0x)4wtHO8#JRG3MA-6g=`ZM!FcICCZ{IEw8Dm2&LQ z1|r)BUG^0GzI6f946RrBlfB1Vs)~8toZf~7)+G;pv&XiUO(%5bm)pl=p>nV^o*;&T z;}@oZSibzto$arQgfkp|z4Z($P>dTXE{4O=vY0!)kDO* zGF8a4wq#VaFpLfK!iELy@?-SeRrdz%F*}hjKcA*y@mj~VD3!it9lhRhX}5YOaR9$} z3mS%$2Be7{l(+MVx3 z(4?h;P!jnRmX9J9sYN#7i=iyj_5q7n#X(!cdqI2lnr8T$IfOW<_v`eB!d9xY1P=2q&WtOXY=D9QYteP)De?S4}FK6#6Ma z=E*V+#s8>L;8aVroK^6iKo=MH{4yEZ_>N-N z`(|;aOATba1^asjxlILk<4}f~`39dBFlxj>Dw(hMYKPO3EEt1@S`1lxFNM+J@uB7T zZ8WKjz7HF1-5&2=l=fqF-*@>n5J}jIxdDwpT?oKM3s8Nr`x8JnN-kCE?~aM1H!hAE z%%w(3kHfGwMnMmNj(SU(w42OrC-euI>Dsjk&jz3ts}WHqmMpzQ3vZrsXrZ|}+MHA7 z068obeXZTsO*6RS@o3x80E4ok``rV^Y3hr&C1;|ZZ0|*EKO`$lECUYG2gVFtUTw)R z4Um<0ZzlON`zTdvVdL#KFoMFQX*a5wM0Czp%wTtfK4Sjs)P**RW&?lP$(<}q%r68Z zS53Y!d@&~ne9O)A^tNrXHhXBkj~$8j%pT1%%mypa9AW5E&s9)rjF4@O3ytH{0z6riz|@< zB~UPh*wRFg2^7EbQrHf0y?E~dHlkOxof_a?M{LqQ^C!i2dawHTPYUE=X@2(3<=OOxs8qn_(y>pU>u^}3y&df{JarR0@VJn0f+U%UiF=$Wyq zQvnVHESil@d|8&R<%}uidGh7@u^(%?$#|&J$pvFC-n8&A>utA=n3#)yMkz+qnG3wd zP7xCnF|$9Dif@N~L)Vde3hW8W!UY0BgT2v(wzp;tlLmyk2%N|0jfG$%<;A&IVrOI< z!L)o>j>;dFaqA3pL}b-Je(bB@VJ4%!JeX@3x!i{yIeIso^=n?fDX`3bU=eG7sTc%g%ye8$v8P@yKE^XD=NYxTb zbf!Mk=h|otpqjFaA-vs5YOF-*GwWPc7VbaOW&stlANnCN8iftFMMrUdYNJ_Bnn5Vt zxfz@Ah|+4&P;reZxp;MmEI7C|FOv8NKUm8njF7Wb6Gi7DeODLl&G~}G4be&*Hi0Qw z5}77vL0P+7-B%UL@3n1&JPxW^d@vVwp?u#gVcJqY9#@-3X{ok#UfW3<1fb%FT`|)V~ggq z(3AUoUS-;7)^hCjdT0Kf{i}h)mBg4qhtHHBti=~h^n^OTH5U*XMgDLIR@sre`AaB$ zg)IGBET_4??m@cx&c~bA80O7B8CHR7(LX7%HThkeC*@vi{-pL%e)yXp!B2InafbDF zjPXf1mko3h59{lT6EEbxKO1Z5GF71)WwowO6kY|6tjSVSWdQ}NsK2x{>i|MKZK8%Q zfu&_0D;CO-Jg0#YmyfctyJ!mRJp)e#@O0mYdp|8x;G1%OZQ3Q847YWTyy|%^cpA;m zze0(5p{tMu^lDkpe?HynyO?a1$_LJl2L&mpeKu%8YvgRNr=%2z${%WThHG=vrWY@4 zsA`OP#O&)TetZ>s%h!=+CE15lOOls&nvC~$Qz0Ph7tHiP;O$i|eDwpT{cp>+)0-|; zY$|bB+Gbel>5aRN3>c0x)4U=|X+z+{ zn*_p*EQoquRL+=+p;=lm`d71&1NqBz&_ph)MXu(Nv6&XE7(RsS)^MGj5Q?Fwude-(sq zjJ>aOq!7!EN>@(fK7EE#;i_BGvli`5U;r!YA{JRodLBc6-`n8K+Fjgwb%sX;j=qHQ z7&Tr!)!{HXoO<2BQrV9Sw?JRaLXV8HrsNevvnf>Y-6|{T!pYLl7jp$-nEE z#X!4G4L#K0qG_4Z;Cj6=;b|Be$hi4JvMH!-voxqx^@8cXp`B??eFBz2lLD8RRaRGh zn7kUfy!YV~p(R|p7iC1Rdgt$_24i0cd-S8HpG|`@my70g^y`gu%#Tf_L21-k?sRRZHK&at(*ED0P8iw{7?R$9~OF$Ko;Iu5)ur5<->x!m93Eb zFYpIx60s=Wxxw=`$aS-O&dCO_9?b1yKiPCQmSQb>T)963`*U+Ydj5kI(B(B?HNP8r z*bfSBpSu)w(Z3j7HQoRjUG(+d=IaE~tv}y14zHHs|0UcN52fT8V_<@2ep_ee{QgZG zmgp8iv4V{k;~8@I%M3<#B;2R>Ef(Gg_cQM7%}0s*^)SK6!Ym+~P^58*wnwV1BW@eG z4sZLqsUvBbFsr#8u7S1r4teQ;t)Y@jnn_m5jS$CsW1um!p&PqAcc8!zyiXHVta9QC zY~wCwCF0U%xiQPD_INKtTb;A|Zf29(mu9NI;E zc-e>*1%(LSXB`g}kd`#}O;veb<(sk~RWL|f3ljxCnEZDdNSTDV6#Td({6l&y4IjKF z^}lIUq*ZUqgTPumD)RrCN{M^jhY>E~1pn|KOZ5((%F)G|*ZQ|r4zIbrEiV%42hJV8 z3xS)=!X1+=olbdGJ=yZil?oXLct8FM{(6ikLL3E%=q#O6(H$p~gQu6T8N!plf!96| z&Q3=`L~>U0zZh;z(pGR2^S^{#PrPxTRHD1RQOON&f)Siaf`GLj#UOk&(|@0?zm;Sx ztsGt8=29-MZs5CSf1l1jNFtNt5rFNZxJPvkNu~2}7*9468TWm>nN9TP&^!;J{-h)_ z7WsHH9|F%I`Pb!>KAS3jQWKfGivTVkMJLO-HUGM_a4UQ_%RgL6WZvrW+Z4ujZn;y@ zz9$=oO!7qVTaQAA^BhX&ZxS*|5dj803M=k&2%QrXda`-Q#IoZL6E(g+tN!6CA!CP* zCpWtCujIea)ENl0liwVfj)Nc<9mV%+e@=d`haoZ*`B7+PNjEbXBkv=B+Pi^~L#EO$D$ZqTiD8f<5$eyb54-(=3 zh)6i8i|jp(@OnRrY5B8t|LFXFQVQ895n*P16cEKTrT*~yLH6Z4e*bZ5otpRDri&+A zfNbK1D5@O=sm`fN=WzWyse!za5n%^+6dHPGX#8DyIK>?9qyX}2XvBWVqbP%%D)7$= z=#$WulZlZR<{m#gU7lwqK4WS1Ne$#_P{b17qe$~UOXCl>5b|6WVh;5vVnR<%d+Lnp z$uEmML38}U4vaW8>shm6CzB(Wei3s#NAWE3)a2)z@i{4jTn;;aQS)O@l{rUM`J@K& l00vQ5JBs~;vo!vr%%-k{2_Fq1Mn4QF81S)AQ99zk{{c4yR+0b! literal 0 HcmV?d00001 diff --git a/extension/android_test/gradle/wrapper/gradle-wrapper.properties b/extension/android_test/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000..a80b22ce5c --- /dev/null +++ b/extension/android_test/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/extension/android_test/gradlew b/extension/android_test/gradlew new file mode 100755 index 0000000000..1aa94a4269 --- /dev/null +++ b/extension/android_test/gradlew @@ -0,0 +1,249 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/extension/android_test/gradlew.bat b/extension/android_test/gradlew.bat new file mode 100644 index 0000000000..25da30dbde --- /dev/null +++ b/extension/android_test/gradlew.bat @@ -0,0 +1,92 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/extension/android_test/settings.gradle b/extension/android_test/settings.gradle new file mode 100644 index 0000000000..6b1bd4f7f8 --- /dev/null +++ b/extension/android_test/settings.gradle @@ -0,0 +1,24 @@ +/* + * This file was generated by the Gradle 'init' task. + * + * The settings file is used to specify which projects to include in your build. + * For more detailed information on multi-project builds, please refer to https://docs.gradle.org/8.6/userguide/multi_project_builds.html in the Gradle documentation. + */ +pluginManagement { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) + repositories { + google() + mavenCentral() + } +} + +rootProject.name = 'executorch' +include('src') diff --git a/extension/android_test/setup.sh b/extension/android_test/setup.sh new file mode 100755 index 0000000000..a12f76c1f3 --- /dev/null +++ b/extension/android_test/setup.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -eu + +BUILD_AAR_DIR="$(mktemp -d)" +export BUILD_AAR_DIR + +BASEDIR=$(dirname "$0") +source "$BASEDIR"/../../build/build_android_llm_demo.sh + +build_native_library() { + ANDROID_ABI="$1" + CMAKE_OUT="cmake-out-android-${ANDROID_ABI}" + EXECUTORCH_CMAKE_BUILD_TYPE="${EXECUTORCH_CMAKE_BUILD_TYPE:-Release}" + cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ + -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ + -DANDROID_ABI="${ANDROID_ABI}" \ + -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -B"${CMAKE_OUT}" + + cmake --build "${CMAKE_OUT}" -j16 --target install + + cmake extension/android \ + -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI="${ANDROID_ABI}" \ + -DCMAKE_INSTALL_PREFIX=c"${CMAKE_OUT}" \ + -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ + -B"${CMAKE_OUT}"/extension/android + + cmake --build "${CMAKE_OUT}"/extension/android -j16 + + # Copy artifacts to ABI specific directory + mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" + cp "${CMAKE_OUT}"/extension/android/*.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/" +} + +pushd "$BASEDIR"/../../ +build_jar +build_native_library "arm64-v8a" +build_native_library "x86_64" +build_aar +popd +mkdir -p "$BASEDIR"/src/libs +cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/src/libs/executorch.aar diff --git a/extension/android_test/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java b/extension/android_test/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java new file mode 100644 index 0000000000..e8259969ab --- /dev/null +++ b/extension/android_test/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import android.os.Environment; +import androidx.test.rule.GrantPermissionRule; +import android.Manifest; +import android.content.Context; +import org.junit.Test; +import org.junit.Before; +import org.junit.Rule; +import org.junit.runner.RunWith; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.io.IOException; +import java.io.File; +import java.io.FileOutputStream; +import org.junit.runners.JUnit4; +import org.apache.commons.io.FileUtils; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.InstrumentationRegistry; +import org.pytorch.executorch.Module; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Tensor; + +/** Unit tests for {@link Module}. */ +@RunWith(AndroidJUnit4.class) +public class ModuleInstrumentationTest { + private static String TEST_FILE_NAME = "/add.pte"; + private static String MISSING_FILE_NAME = "/missing.pte"; + private static String NON_PTE_FILE_NAME = "/test.txt"; + private static String FORWARD_METHOD = "forward"; + private static String NONE_METHOD = "none"; + private static int OK = 0x00; + private static int INVALID_ARGUMENT = 0x12; + private static int ACCESS_FAILED = 0x22; + + private static String getTestFilePath(String fileName) { + return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName; + } + + @Before + public void setUp() throws IOException { + // copy zipped test resources to local device + File addPteFile = new File(getTestFilePath(TEST_FILE_NAME)); + InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, addPteFile); + inputStream.close(); + + File nonPteFile = new File(getTestFilePath(NON_PTE_FILE_NAME)); + inputStream = getClass().getResourceAsStream(NON_PTE_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, nonPteFile); + inputStream.close(); + } + + @Rule + public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE); + + @Test + public void testModuleLoadAndForward() throws IOException, URISyntaxException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + EValue[] results = module.forward(); + assertTrue(results[0].isTensor()); + } + + @Test + public void testModuleLoadMethodAndForward() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, OK); + + EValue[] results = module.forward(); + assertTrue(results[0].isTensor()); + } + + @Test + public void testModuleLoadForwardExplicit() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + EValue[] results = module.execute(FORWARD_METHOD); + assertTrue(results[0].isTensor()); + } + + @Test + public void testModuleLoadNonExistantFile() throws IOException{ + Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); + + EValue[] results = module.forward(); + assertEquals(null, results); + } + + @Test + public void testModuleLoadMethodNonExistantFile() throws IOException{ + Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, ACCESS_FAILED); + } + + @Test + public void testModuleLoadMethodNonExistantMethod() throws IOException{ + Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); + + int loadMethod = module.loadMethod(NONE_METHOD); + assertEquals(loadMethod, INVALID_ARGUMENT); + } + + @Test + public void testNonPteFile() throws IOException{ + Module module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)); + + int loadMethod = module.loadMethod(FORWARD_METHOD); + assertEquals(loadMethod, INVALID_ARGUMENT); + } +} diff --git a/extension/android_test/src/androidTest/resources/test.txt b/extension/android_test/src/androidTest/resources/test.txt new file mode 100644 index 0000000000..039461e6a9 --- /dev/null +++ b/extension/android_test/src/androidTest/resources/test.txt @@ -0,0 +1 @@ +non pte file diff --git a/extension/android_test/src/main/AndroidManifest.xml b/extension/android_test/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..b8ac862938 --- /dev/null +++ b/extension/android_test/src/main/AndroidManifest.xml @@ -0,0 +1,12 @@ + + + + + + + + + From 2660287a9176a3797a2ba3e6997e597f734bda1e Mon Sep 17 00:00:00 2001 From: David Lin Date: Mon, 11 Nov 2024 12:44:20 -0800 Subject: [PATCH 49/59] added instrumentation test for LlamaModule (#6759) Added instrumentation test for LlamaModule. Modified setup.sh to include building stories110M model and moves it into src/androidTest/resources Added test cases for LlamaModule by generating a sequence length of 32, and verifying the length. Also verifies that stop() works by checking output length is less than input sequence length [ghstack-poisoned] --- extension/android_test/setup.sh | 6 + .../LlamaModuleInstrumentationTest.java | 119 ++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java diff --git a/extension/android_test/setup.sh b/extension/android_test/setup.sh index a12f76c1f3..fff4bdf60e 100755 --- a/extension/android_test/setup.sh +++ b/extension/android_test/setup.sh @@ -21,10 +21,13 @@ build_native_library() { -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI="${ANDROID_ABI}" \ -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ -B"${CMAKE_OUT}" cmake --build "${CMAKE_OUT}" -j16 --target install @@ -33,6 +36,7 @@ build_native_library() { -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \ -DANDROID_ABI="${ANDROID_ABI}" \ -DCMAKE_INSTALL_PREFIX=c"${CMAKE_OUT}" \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ -B"${CMAKE_OUT}"/extension/android @@ -48,6 +52,8 @@ build_jar build_native_library "arm64-v8a" build_native_library "x86_64" build_aar +source ".ci/scripts/test_llama.sh" stories110M cmake fp16 portable ${BUILD_AAR_DIR} popd mkdir -p "$BASEDIR"/src/libs cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/src/libs/executorch.aar +unzip -o "$BUILD_AAR_DIR"/model.zip -d "$BASEDIR"/src/androidTest/resources diff --git a/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java b/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java new file mode 100644 index 0000000000..940e34d684 --- /dev/null +++ b/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import android.os.Environment; +import androidx.test.rule.GrantPermissionRule; +import android.Manifest; +import android.content.Context; +import org.junit.Test; +import org.junit.Before; +import org.junit.Rule; +import org.junit.runner.RunWith; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.ArrayList; +import java.io.IOException; +import java.io.File; +import java.io.FileOutputStream; +import org.junit.runners.JUnit4; +import org.apache.commons.io.FileUtils; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.InstrumentationRegistry; +import org.pytorch.executorch.LlamaModule; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.Module; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Tensor; + +/** Unit tests for {@link LlamaModule}. */ +@RunWith(AndroidJUnit4.class) +public class LlamaModuleInstrumentationTest implements LlamaCallback { + private static String TEST_FILE_NAME = "/tinyllama_portable_fp16_h.pte"; + private static String TOKENIZER_FILE_NAME = "/tokenizer.bin"; + private static String TEST_PROMPT = "Hello"; + private static int OK = 0x00; + private static int SEQ_LEN = 32; + + private final List results = new ArrayList<>(); + private final List tokensPerSecond = new ArrayList<>(); + private LlamaModule mModule; + + private static String getTestFilePath(String fileName) { + return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName; + } + + @Before + public void setUp() throws IOException { + // copy zipped test resources to local device + File addPteFile = new File(getTestFilePath(TEST_FILE_NAME)); + InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, addPteFile); + inputStream.close(); + + File tokenizerFile = new File(getTestFilePath(TOKENIZER_FILE_NAME)); + inputStream = getClass().getResourceAsStream(TOKENIZER_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, tokenizerFile); + inputStream.close(); + + mModule = new LlamaModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f); + } + + @Rule + public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE); + + @Test + public void testGenerate() throws IOException, URISyntaxException{ + int loadResult = mModule.load(); + // Check that the model can be load successfully + assertEquals(OK, loadResult); + + mModule.generate(TEST_PROMPT, SEQ_LEN, LlamaModuleInstrumentationTest.this); + assertEquals(results.size(), SEQ_LEN); + assertTrue(tokensPerSecond.get(tokensPerSecond.size() - 1) > 0); + } + + @Test + public void testGenerateAndStop() throws IOException, URISyntaxException{ + int seqLen = 32; + mModule.generate(TEST_PROMPT, SEQ_LEN, new LlamaCallback() { + @Override + public void onResult(String result) { + LlamaModuleInstrumentationTest.this.onResult(result); + mModule.stop(); + } + + @Override + public void onStats(float tps) { + LlamaModuleInstrumentationTest.this.onStats(tps); + } + }); + + int stoppedResultSize = results.size(); + assertTrue(stoppedResultSize < SEQ_LEN); + } + + @Override + public void onResult(String result) { + results.add(result); + } + + @Override + public void onStats(float tps) { + tokensPerSecond.add(tps); + } +} From d544f94929cedbf79b8fa1e6605d5b33aba3f9dc Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 11 Nov 2024 12:46:41 -0800 Subject: [PATCH 50/59] [ET-VK] Statically link MoltenVK (#6762) TSIA Differential Revision: [D65769129](https://our.internmc.facebook.com/intern/diff/D65769129/) ghstack-source-id: 252926807 Pull Request resolved: https://github.com/pytorch/executorch/pull/6757 Co-authored-by: Stephen Jia --- backends/vulkan/targets.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 2c4671afa0..c2b46774aa 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -118,7 +118,7 @@ def define_common_targets(is_fbcode = False): "fbsource//third-party/toolchains:android" ], "ovr_config//os:macos-arm64": [ - "//third-party/khronos:moltenVK" + "//third-party/khronos:moltenVK_static" ], }) VK_API_PREPROCESSOR_FLAGS += select({ From c411a75dc8d3243520ecb96bd62c6a4ab7aecaf1 Mon Sep 17 00:00:00 2001 From: David Lin Date: Mon, 11 Nov 2024 13:00:06 -0800 Subject: [PATCH 51/59] add script to generate add.pte (#6760) * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- extension/android_test/add_model.py | 26 ++++++++++++++++++++++++++ extension/android_test/setup.sh | 2 ++ 2 files changed, 28 insertions(+) create mode 100644 extension/android_test/add_model.py diff --git a/extension/android_test/add_model.py b/extension/android_test/add_model.py new file mode 100644 index 0000000000..5c7cf4770e --- /dev/null +++ b/extension/android_test/add_model.py @@ -0,0 +1,26 @@ +import torch +from executorch.exir import to_edge +from torch.export import export + + +# Start with a PyTorch model that adds two input tensors (matrices) +class Add(torch.nn.Module): + def __init__(self): + super(Add, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + y + + +# 1. torch.export: Defines the program with the ATen operator set. +aten_dialect = export(Add(), (torch.ones(1), torch.ones(1))) + +# 2. to_edge: Make optimizations for Edge devices +edge_program = to_edge(aten_dialect) + +# 3. to_executorch: Convert the graph to an ExecuTorch program +executorch_program = edge_program.to_executorch() + +# 4. Save the compiled .pte program +with open("add.pte", "wb") as file: + file.write(executorch_program.buffer) diff --git a/extension/android_test/setup.sh b/extension/android_test/setup.sh index fff4bdf60e..d83aeeebb4 100755 --- a/extension/android_test/setup.sh +++ b/extension/android_test/setup.sh @@ -56,4 +56,6 @@ source ".ci/scripts/test_llama.sh" stories110M cmake fp16 portable ${BUILD_AAR_D popd mkdir -p "$BASEDIR"/src/libs cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/src/libs/executorch.aar +python add_model.py +mv "add.pte" "$BASEDIR"/src/androidTest/resources/add.pte unzip -o "$BUILD_AAR_DIR"/model.zip -d "$BASEDIR"/src/androidTest/resources From b8b5146ee8d609f1f17c5c823ada5cb0d6f95be3 Mon Sep 17 00:00:00 2001 From: David Lin Date: Mon, 11 Nov 2024 13:00:25 -0800 Subject: [PATCH 52/59] move junit tests to android_test (#6761) Move JUnit tests to android_test to centralize both unit tests and instrumentation tests. --- extension/android/build.gradle | 1 - .../src/test/java/org/pytorch/executorch/EValueTest.java | 2 +- .../src/test/java/org/pytorch/executorch/TensorTest.java | 0 3 files changed, 1 insertion(+), 2 deletions(-) rename extension/{android => android_test}/src/test/java/org/pytorch/executorch/EValueTest.java (99%) rename extension/{android => android_test}/src/test/java/org/pytorch/executorch/TensorTest.java (100%) diff --git a/extension/android/build.gradle b/extension/android/build.gradle index de243154d6..b40f08e0c4 100644 --- a/extension/android/build.gradle +++ b/extension/android/build.gradle @@ -20,6 +20,5 @@ task makeJar(type: Jar) { dependencies { implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' implementation 'com.facebook.soloader:nativeloader:0.10.5' - testImplementation 'junit:junit:4.13.2' } } diff --git a/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java similarity index 99% rename from extension/android/src/test/java/org/pytorch/executorch/EValueTest.java rename to extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java index 35367883ef..29cabae75f 100644 --- a/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java +++ b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java @@ -129,7 +129,7 @@ public void testOptionalTensorListValue() { Optional.of(Tensor.fromBlob(data[1], shape[1]))); assertTrue(evalue.isOptionalTensorList()); - assertTrue(evalue.toOptionalTensorList()[0].isEmpty()); + assertTrue(!evalue.toOptionalTensorList()[0].isPresent()); assertTrue(evalue.toOptionalTensorList()[1].isPresent()); assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0])); diff --git a/extension/android/src/test/java/org/pytorch/executorch/TensorTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java similarity index 100% rename from extension/android/src/test/java/org/pytorch/executorch/TensorTest.java rename to extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java From 6887ae960a3a85c713e2840ca8301331166ff6ee Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 11 Nov 2024 13:35:43 -0800 Subject: [PATCH 53/59] [ET-VK] Update partitioner to account for custom packed arguments (#6763) ## Problem Convolution operators, especially for pointwise convolution, may have sizes like ``` W=1, H=1, C=320, N=1280 ``` When represented as a texture, this tensor would normally require a texture with extents ``` (1, 1, 320 / 4 * 1280 = 102400) ``` which would normally exceed texture limits. The new partitioner system detects this and prevents nodes with similar weights from being lowered to Vulkan. However, the partitioner system does not account for the fact that the operator implementation uses a specialized prepacking algorithm which results in valid texture limits for the packed weights. ## Changes * Add field to `OpFeatures` class to annotate that some arguments in an op should be skipped when checking against texture limits * Update metadata tagging pass to ignore annotating constant tensor nodes so that they don't influence memory layout and storage type proposals. Without this change, the tagging pass will try to use buffer storage for the pointwise convolution since the weight can only be represented as a buffer under normal circumstances. Differential Revision: [D65759236](https://our.internmc.facebook.com/intern/diff/D65759236/) ghstack-source-id: 252885980 Pull Request resolved: https://github.com/pytorch/executorch/pull/6753 Co-authored-by: Stephen Jia --- .../vulkan/_passes/insert_prepack_nodes.py | 4 + .../vulkan/_passes/tag_memory_meta_pass.py | 73 +++++++++++++------ backends/vulkan/op_registry.py | 10 +++ .../vulkan/partitioner/vulkan_partitioner.py | 9 ++- 4 files changed, 73 insertions(+), 23 deletions(-) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index 37665a6da8..7876806d6d 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool: if not is_param_node(program, node): return True + # Annotate that this node is going to represented as a tensorref in the Vulkan + # compute graph. This will be useful for later graph passes. + node.meta["vkdg_tensorref"] = True + for user in node.users: if user.op == "call_function" and handles_own_prepacking( # pyre-ignore diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index fd0bd3648e..0a6a2d42d4 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -39,6 +39,30 @@ def set_memory_metadata( utils.set_node_spec_attr(node, "vk_memory_layout", layout) +def insert_transition_node( + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + arg: torch.fx.Node, + storage: VkStorageType, + layout: VkMemoryLayout, +) -> None: + """ + Insert a clone node to copy the original tensor to a tensor with the desired storage + type and memory layout. + """ + with graph_module.graph.inserting_before(node): + clone_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten.clone.default, + (arg,), + ) + clone_node.meta["val"] = arg.meta["val"] + clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) + clone_node.meta["spec"].const = False + set_memory_metadata(clone_node, storage, layout) + arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) + + class TagMemoryMetaPass(ExportPass): """ There are a variety of ways that tensors can be represented in Vulkan. The two main @@ -174,14 +198,33 @@ def propose_node_layout( else: return next(iter(valid_layouts)) + def should_annotate(self, node) -> bool: + if not isinstance(node, torch.fx.Node): + return False + + if not isinstance(node.meta["val"], FakeTensor): + return False + + # Storage type and memory layout for tensorref will be determined at runtime + # so there's no use in setting those attributes ahead of time. + if node.meta.get("vkdg_tensorref", False): + return False + + return True + + def should_delay_annotation(self, node: torch.fx.Node) -> bool: + # For prepack nodes, delay setting the storage type and memory layout as long as + # possible. This is to minimize the number of transitions, since it can be + # difficult to predict what storage type and memory layout should be used at the + # time the prepack node is observed. + return node.target == exir_ops.edge.et_vk.prepack.default + + # noqa def call(self, graph_module: torch.fx.GraphModule) -> PassResult: sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes)) for node in sorted_nodes: - if not isinstance(node.meta["val"], FakeTensor): - continue - - if node.target == exir_ops.edge.et_vk.prepack.default: + if not self.should_annotate(node) or self.should_delay_annotation(node): continue storage = self.propose_node_storage(node) @@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: inserting_transitions_for_node = False for i, arg in enumerate(node.args): - if not isinstance(arg, torch.fx.Node): - continue - if not isinstance(arg.meta["val"], FakeTensor): + if not self.should_annotate(arg): continue + assert isinstance(arg, torch.fx.Node) + arg_storage = utils.get_node_storage_type(arg) arg_layout = utils.get_node_memory_layout(arg) @@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:" ) + insert_transition_node(graph_module, node, arg, storage, layout) + logger.info( f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})" ) - # Insert a clone node to copy the original tensor to a tensor with the - # desired storage type and memory layout. - with graph_module.graph.inserting_before(node): - clone_node = graph_module.graph.create_node( - "call_function", - exir_ops.edge.aten.clone.default, - (arg,), - ) - clone_node.meta["val"] = arg.meta["val"] - clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) - clone_node.meta["spec"].const = False - set_memory_metadata(clone_node, storage, layout) - arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) - return PassResult(graph_module, True) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 3a6191bccb..eeec5ab37e 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -90,6 +90,9 @@ class OpFeatures: # then the insert_prepack_nodes pass will not insert prepack nodes for the args # of the op. "handles_own_prepacking", + # Optional dictionary to specify a custom function to calculate the required + # image extents for a particular argument index. + "skip_limits_check", # Optional check function used during partitioning to determine if a node's # inputs are supported by the operator implementation. "check_node_fn", @@ -103,6 +106,7 @@ def __init__( optimal_storage: Optional[VkStorageType] = None, optimal_layout: Optional[VkMemoryLayout] = None, handles_own_prepacking: bool = False, + skip_limits_check: Optional[Set[int]] = None, check_node_fn: Optional[Callable] = None, ): self.texture_impl: Optional[TextureImplFeatures] = texture_impl @@ -111,6 +115,11 @@ def __init__( self.optimal_storage: Optional[VkStorageType] = optimal_storage self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout self.handles_own_prepacking: bool = handles_own_prepacking + + self.skip_limits_check: Set[int] = set() + if skip_limits_check is not None: + self.skip_limits_check = skip_limits_check + self.check_node_fn: Callable = allow_node if check_node_fn is not None: self.check_node_fn = check_node_fn @@ -433,6 +442,7 @@ def register_convolution_op(features: OpFeatures): features.optimal_storage = VkStorageType.TEXTURE_3D features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED features.handles_own_prepacking = True + features.skip_limits_check = {1, 2} return features diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 7b2ad3fdfd..64e672fd69 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -82,8 +82,13 @@ def op_node_is_compatible( valid_texture_layouts = utils.possible_node_memory_layouts( node, self.texture_limits ) - for arg in node.args: - if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): + + for i, arg in enumerate(node.args): + if ( + isinstance(arg, torch.fx.Node) + and utils.is_tensor_node(arg) + and i not in features.skip_limits_check + ): arg_texture_layouts = utils.possible_node_memory_layouts( arg, self.texture_limits ) From 671f9c50ca1d581d4d34e38dbe8509c8040d5323 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 11 Nov 2024 14:09:50 -0800 Subject: [PATCH 54/59] update llama runner to decode single token (#6768) Pull Request resolved: https://github.com/pytorch/executorch/pull/6703 Right now, we don't print the generated response in the eager runner until all tokens are generated. This is not good experience as we need to wait until all tokens are generated to see the response. This PR updates it to decode each new token immediately after it is generated. ghstack-source-id: 252924039 Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/) Co-authored-by: Lunwen He --- .ci/scripts/test_llama_runner_eager.sh | 3 ++- examples/models/llama/runner/eager.py | 16 ++++++++----- examples/models/llama/runner/generation.py | 25 ++++++++------------- examples/models/llama/runner/native.py | 8 ++----- examples/models/llama/tokenizer/tiktoken.py | 12 ++++++++++ extension/llm/tokenizer/tokenizer.py | 4 ++++ 6 files changed, 39 insertions(+), 29 deletions(-) diff --git a/.ci/scripts/test_llama_runner_eager.sh b/.ci/scripts/test_llama_runner_eager.sh index 537d835ba1..0f2cb7b376 100644 --- a/.ci/scripts/test_llama_runner_eager.sh +++ b/.ci/scripts/test_llama_runner_eager.sh @@ -42,11 +42,12 @@ run_and_verify() { -d fp32 \ --max_seq_length 32 \ --temperature 0 \ + --show_tokens \ --prompt "Once upon a time," > result.txt # Verify result.txt RESULT=$(cat result.txt) - EXPECTED_RESULT="there was a little girl" + EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263" if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then echo "Actual result: ${RESULT}" echo "Success" diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index b8792151a0..abac920c6b 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser: default=0, ) + parser.add_argument( + "--show_tokens", + action="store_true", + default=False, + help="Show the tokens that were generated", + ) + return parser @@ -71,15 +78,12 @@ def main() -> None: args = parser.parse_args() runner = EagerLlamaRunner(args) - result = runner.text_completion( + generated_tokens = runner.text_completion( prompt=args.prompt, temperature=args.temperature, ) - print( - "Response: \n{response}\n Tokens:\n {tokens}".format( - response=result["generation"], tokens=result["tokens"] - ) - ) + if args.show_tokens: + print(f"Tokens: {generated_tokens}") if __name__ == "__main__": diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 867c41aabe..159bc5f501 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import List, Optional, TypedDict +from typing import List, Optional import torch @@ -13,11 +13,6 @@ from executorch.extension.llm.tokenizer.utils import get_tokenizer -class CompletionPrediction(TypedDict, total=False): - generation: str - tokens: List[int] # not required - - def sample_top_p(probs, p): """ Perform top-p (nucleus) sampling on a probability distribution. @@ -84,6 +79,7 @@ def generate( # noqa: C901 ) current_token = next_token(logits, temperature, top_p) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) tokens = prompt_tokens + [current_token] while len(tokens) < self.params.max_seq_len: @@ -101,12 +97,14 @@ def generate( # noqa: C901 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device), ) current_token = next_token(logits, temperature, top_p) + tokens.append(current_token) if current_token == self.tokenizer.eos_id or ( hasattr(self.tokenizer, "stop_tokens") and current_token in self.tokenizer.stop_tokens ): break - tokens.append(current_token) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + print("\n") return tokens if echo else tokens[len(prompt_tokens) :] @@ -116,7 +114,7 @@ def text_completion( temperature: float = 0.6, top_p: float = 0.9, echo: bool = False, - ) -> CompletionPrediction: + ) -> List[int]: """ Perform text completion for a prompt using the language model. @@ -127,19 +125,14 @@ def text_completion( echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. Returns: - CompletionPrediction: Completion prediction, which contains the generated text completion. + Generated list of tokens. Note: This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. """ - prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False) - generation_tokens = self.generate( - prompt_tokens=prompt_tokens, + return self.generate( + prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), temperature=temperature, top_p=top_p, echo=echo, ) - return { - "generation": self.tokenizer.decode(generation_tokens), - "tokens": generation_tokens, - } diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index 73005d9333..19e5791598 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -107,15 +107,11 @@ def main() -> None: parser = build_args_parser() args = parser.parse_args() runner = NativeLlamaRunner(args) - result = runner.text_completion( + generated_tokens = runner.text_completion( prompt=args.prompt, temperature=args.temperature, ) - print( - "Response: \n{response}\n Tokens:\n {tokens}".format( - response=result["generation"], tokens=result["tokens"] - ) - ) + print(f"Response: {generated_tokens}") if __name__ == "__main__": diff --git a/examples/models/llama/tokenizer/tiktoken.py b/examples/models/llama/tokenizer/tiktoken.py index 1d74e5e3aa..b48cb4dc89 100644 --- a/examples/models/llama/tokenizer/tiktoken.py +++ b/examples/models/llama/tokenizer/tiktoken.py @@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str: # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. return self.model.decode(cast(List[int], t)) + def decode_token(self, t: int) -> str: + """ + Decodes a single token ID into a string. + + Args: + t (int): The token ID to be decoded. + + Returns: + str: The decoded string. + """ + return self.model.decode_single_token_bytes(t).decode("utf-8") + @staticmethod def _split_whitespaces_or_nonwhitespaces( s: str, max_consecutive_slice_len: int diff --git a/extension/llm/tokenizer/tokenizer.py b/extension/llm/tokenizer/tokenizer.py index ecd0231fb6..78377230b9 100644 --- a/extension/llm/tokenizer/tokenizer.py +++ b/extension/llm/tokenizer/tokenizer.py @@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str: # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. return self.sp_model.decode(t) + def decode_token(self, t: int) -> str: + # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. + return self.sp_model.decode(t) + def export(self, output_path: str, *, prepend_padding: bool = False) -> None: """ Export tokenizer.model to another serialization format. Here we did some lightweight From 623a9a61a860ed2e18364cf0715c5c898209427b Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 11 Nov 2024 14:13:34 -0800 Subject: [PATCH 55/59] add the ability to have multi-round conversation with llama (#6769) * update llama runner to decode single token Pull Request resolved: https://github.com/pytorch/executorch/pull/6703 Right now, we don't print the generated response in the eager runner until all tokens are generated. This is not good experience as we need to wait until all tokens are generated to see the response. This PR updates it to decode each new token immediately after it is generated. ghstack-source-id: 252924039 Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/) * add the ability to have multi-round conversation with llama Ad the ability to have multi-round conversations with LLM. This will be helpful for testing long context length. Differential Revision: [D65771122](https://our.internmc.facebook.com/intern/diff/D65771122/) ghstack-source-id: 252934165 Pull Request resolved: https://github.com/pytorch/executorch/pull/6758 --------- Co-authored-by: Lunwen He --- examples/models/llama/runner/eager.py | 19 ++++++-- examples/models/llama/runner/generation.py | 53 +++++++++++++++++++++- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index abac920c6b..9745fdd542 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument( "--prompt", type=str, - default="Hello", + default=None, ) parser.add_argument( @@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser: help="Show the tokens that were generated", ) + parser.add_argument( + "--chat", + action="store_true", + default=False, + help="Have multi-turn chat with the model", + ) + return parser @@ -78,9 +85,13 @@ def main() -> None: args = parser.parse_args() runner = EagerLlamaRunner(args) - generated_tokens = runner.text_completion( - prompt=args.prompt, - temperature=args.temperature, + generated_tokens = ( + runner.chat_completion(temperature=args.temperature) + if args.chat + else runner.text_completion( + prompt=args.prompt, + temperature=args.temperature, + ) ) if args.show_tokens: print(f"Tokens: {generated_tokens}") diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 159bc5f501..ed25d44b6f 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -67,12 +67,13 @@ def generate( # noqa: C901 temperature: float = 0.8, top_p: float = 0.9, echo: bool = False, + pos_base: int = 0, ) -> List[int]: # prefill logits = self.forward( tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), input_pos=( - torch.tensor([0], dtype=torch.long, device=self.device) + torch.tensor([pos_base], dtype=torch.long, device=self.device) if self.params.use_kv_cache else None ), @@ -89,7 +90,9 @@ def generate( # noqa: C901 [[current_token]], dtype=torch.long, device=self.device ), input_pos=torch.tensor( - [len(tokens) - 1], dtype=torch.long, device=self.device + [pos_base + len(tokens) - 1], + dtype=torch.long, + device=self.device, ), ) else: @@ -136,3 +139,49 @@ def text_completion( top_p=top_p, echo=echo, ) + + def chat_completion( + self, + temperature: float = 0.6, + top_p: float = 0.9, + ) -> List[int]: + """ + Perform multi-turn chat with the language model. + + Args: + prompt (str): Text prompt for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Generated list of tokens. + + Note: + This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. + """ + exit_prompt = "exit" + tokens = [] + prompt = input("Me: ") + while prompt and prompt != exit_prompt: + print("LLM: ", end="", flush=True) + new_tokens = self.generate( + prompt_tokens=self.tokenizer.encode( + self._format_prompt(prompt), bos=True, eos=False + ), + temperature=temperature, + top_p=top_p, + echo=True, + pos_base=len(tokens), + ) + tokens.extend(new_tokens) + prompt = input("Me: ") + return tokens + + def _format_prompt(self, prompt: str) -> str: + return f""" +<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|> + +{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>""" From 576e96cfd3e764006f5cc5b2f6bfdc6e93f0cbbf Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Tue, 12 Nov 2024 06:39:53 +0800 Subject: [PATCH 56/59] Qualcomm AI Engine Direct - Add llama sha transforming pass Differential Revision: D64435128 Pull Request resolved: https://github.com/pytorch/executorch/pull/6211 --- examples/models/llama/TARGETS | 1 + examples/models/llama/export_llama.py | 3 + examples/models/llama/export_llama_lib.py | 59 +++-- examples/models/llama/llama_transformer.py | 1 - .../llama/source_transformation/attention.py | 219 ++++++++++++++++++ 5 files changed, 267 insertions(+), 16 deletions(-) create mode 100644 examples/models/llama/source_transformation/attention.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index d328adffbf..cf387bfab2 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -82,6 +82,7 @@ runtime.python_library( "export_llama_lib.py", "model.py", "source_transformation/apply_spin_quant_r1_r2.py", + "source_transformation/attention.py", "source_transformation/lora.py", "source_transformation/pre_quantization.py", "source_transformation/prune_vocab.py", diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index 3d0d1b7bcf..1899ccf4df 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -7,11 +7,14 @@ # Example script for exporting Llama2 to flatbuffer import logging +import sys import torch from .export_llama_lib import build_args_parser, export_llama +sys.setrecursionlimit(4096) + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index e1a8d1d06b..817f116c92 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -50,6 +50,8 @@ fuse_layer_norms, get_model_with_r1_r2, ) + +from .source_transformation.attention import replace_attention_to_attention_sha from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, @@ -175,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser: help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.", ) + parser.add_argument( + "--use_qnn_sha", + action="store_true", + help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)", + ) + parser.add_argument( "--calibration_tasks", nargs="+", @@ -700,15 +708,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 get_custom_quant_ios_dtype, ) + atten = builder_exported_to_edge.model.layers[0].attention + if args.use_qnn_sha: + cache_shape = torch.Size( + (atten.max_batch_size, atten.max_seq_len, atten.head_dim) + ) + else: + cache_shape = torch.Size( + ( + atten.max_batch_size, + atten.max_seq_len, + atten.n_kv_heads, + atten.head_dim, + ) + ) # pyre-ignore tag_quant_io( builder_exported_to_edge.edge_manager.exported_program().graph_module, - partial( - get_custom_quant_ios_dtype, # pyre-ignore - builder_exported_to_edge.model.layers[ - 0 - ].attention.kv_cache.past_k_caches.shape, - ), + partial(get_custom_quant_ios_dtype, cache_shape), # pyre-ignore ) logging.info("Lowering model using following partitioner(s): ") @@ -977,15 +994,27 @@ def _get_source_transforms( # noqa convert_linear_to_conv2d, ) - transforms.append(replace_kv_cache_with_simple_kv_cache) - transforms.append(replace_sdpa_with_flex_sdpa) - transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) - if args.optimized_rotation_path: - transforms.append(fuse_layer_norms) - transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) - # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. - transforms.append(convert_linear_to_conv2d) + if args.use_qnn_sha: + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(replace_attention_to_attention_sha) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + transforms.append(convert_linear_to_conv2d) + else: + transforms.append(replace_kv_cache_with_simple_kv_cache) + transforms.append(replace_sdpa_with_flex_sdpa) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append( + get_model_with_r1_r2(args.optimized_rotation_path) + ) + transforms.append(convert_linear_to_conv2d) elif args.mps: # Currently mps doesn't support sdpa op, use the simpler decomposition diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 76e8730328..20b8b1e30d 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -276,7 +276,6 @@ def __init__(self, args: ModelArgs, layer_id: int): self.max_batch_size = args.max_batch_size self.max_seq_len = args.max_seq_len self.dim = args.dim - # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125 self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py new file mode 100644 index 0000000000..c5a028d340 --- /dev/null +++ b/examples/models/llama/source_transformation/attention.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# Example script for exporting Llama2 to flatbuffer + +import math +from typing import List, Optional, Tuple + +import torch +from executorch.examples.models.llama.llama_transformer import Attention +from torch import nn + + +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + x_r, x_i = x[..., ::2], x[..., 1::2] + + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + +class KVCacheSHA(torch.nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype=torch.float32, + ): + super().__init__() + + # a buffer per head + cache_shape = (max_batch_size, max_seq_length, head_dim) + for i in range(n_heads): + self.register_buffer( + f"past_k_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + self.register_buffer( + f"past_v_caches_{i}", + torch.zeros(cache_shape, dtype=dtype, device="cpu"), + persistent=False, + ) + + def update( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + cache_idx: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + new_k = torch.ops.aten.index_put_( + getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val + ) + new_v = torch.ops.aten.index_put_( + getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val + ) + return new_k, new_v + + def get_cache(self, head_idx): + return getattr(self, f"past_k_caches_{head_idx}"), getattr( + self, f"past_v_caches_{head_idx}" + ) + + +class SDPASHA(torch.nn.Module): + + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + n_rep: int, + head_dim: int, + dim: int, + ): + super().__init__() + self.head_dim = head_dim + self.n_rep = n_rep + self.dim = dim + self.kv_cache = KVCacheSHA( + max_batch_size, max_seq_length, n_heads // n_rep, head_dim + ) + self.scale_factor = math.sqrt(head_dim) + + def forward( + self, + input_pos: torch.Tensor, + qs: List[torch.Tensor], + ks: List[torch.Tensor], + vs: List[torch.Tensor], + mask, + ): + + transpose_ks = [] + for i in range(len(ks)): + new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i) + transpose_ks.append(new_k.transpose(-2, -1).contiguous()) + + output = [] + for i, q in enumerate(qs): + cache_idx = i // self.n_rep + _, v = self.kv_cache.get_cache(cache_idx) + + attn_mask = mask[input_pos] + + attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + output.append(attn_weight @ v.contiguous()) + + return torch.cat(output, dim=-1) + + +class AttentionSHA(nn.Module): + def __init__(self, attention_mha: nn.Module): + super().__init__() + if not attention_mha.use_kv_cache: + raise NotImplementedError("bert mode is not support") + + self.n_heads = attention_mha.n_heads + self.n_kv_heads = attention_mha.n_kv_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.dim = attention_mha.dim + self.max_batch_size = attention_mha.max_batch_size + self.max_seq_len = attention_mha.max_seq_len + self.head_dim = attention_mha.dim // self.n_heads + self.SDPA = SDPASHA( + self.max_batch_size, + self.max_seq_len, + self.n_heads, + self.n_rep, + self.head_dim, + self.dim, + ) + self.wq = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + self.wv = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + + for i in range(self.n_heads): + self.wq[i].weight.data.copy_( + attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + for i in range(self.n_kv_heads): + self.wk[i].weight.data.copy_( + attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wv[i].weight.data.copy_( + attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wo = attention_mha.wo + + causal_mask = torch.tril( + torch.ones( + self.max_seq_len, + self.max_seq_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + ): + # QKV + q = [wq(x) for wq in self.wq] + k = [wk(x) for wk in self.wk] + v = [wv(x) for wv in self.wv] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + for i in range(len(k)): + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin) + + output = self.SDPA(input_pos, q, k, v, self.mask) + return self.wo(output) + + +def replace_attention_to_attention_sha(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, Attention): + setattr( + module, + name, + AttentionSHA(child), + ) + else: + replace_attention_to_attention_sha(child) + return module From f90cf2d0e990d3ec98d1af0f092bb673dbe7b2b1 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 11 Nov 2024 17:46:06 -0500 Subject: [PATCH 57/59] Tighten type hints for tensor arithmetic Differential Revision: D65753120 Pull Request resolved: https://github.com/pytorch/executorch/pull/6752 --- devtools/inspector/_inspector_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index c2e92f0914..83492f9963 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -218,6 +218,7 @@ def verify_debug_data_equivalence( if isinstance(output_a, torch.Tensor): assert bool( + # pyre-fixme[6]: For 1st argument expected `Tensor` but got `bool`. torch.all(output_a == output_b) ), "Tensors Debug Data is different. Expected to be equal." else: From 4947e273709366d256c5745ba21f750ad81c211a Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Mon, 11 Nov 2024 14:52:31 -0800 Subject: [PATCH 58/59] Fix internal pyre test Differential Revision: D65782663 Pull Request resolved: https://github.com/pytorch/executorch/pull/6770 --- backends/arm/arm_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index b55f237543..47c3c2d5e5 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -52,6 +52,7 @@ def __init__(self): # TODO MLETORCH-265 Remove permute_nhwc flag self.permute_nhwc = False self.quantize_io = False + self.tosa_version = None def ethosu_compile_spec( self, From dc41596b1a61ca3de7b5ad35e0692a78d7b185eb Mon Sep 17 00:00:00 2001 From: JP <46308822+zonglinpeng@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:40:45 -0800 Subject: [PATCH 59/59] migrate utils from jarvis to cadence Differential Revision: D65458848 Pull Request resolved: https://github.com/pytorch/executorch/pull/6720 --- backends/cadence/aot/TARGETS | 13 ++ backends/cadence/aot/pass_utils.py | 8 +- .../cadence/aot/tests/test_pass_filter.py | 160 ++++++++++++++++++ 3 files changed, 177 insertions(+), 4 deletions(-) create mode 100644 backends/cadence/aot/tests/test_pass_filter.py diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 9876e59dbf..74deed0628 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -11,6 +11,7 @@ load( "CXX", ) load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib") +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") oncall("odai_jarvis") @@ -103,3 +104,15 @@ executorch_generated_lib( "//executorch/kernels/portable:operators", ], ) + +python_unittest( + name = "test_pass_filter", + srcs = [ + "tests/test_pass_filter.py", + ], + typing = True, + deps = [ + ":pass_utils", + "//executorch/exir:pass_base", + ], +) diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 3aa6f48a31..12a2f62238 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -28,11 +28,11 @@ class CadencePassAttribute: # A dictionary that maps an ExportPass to its attributes. -_ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {} +ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {} def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute: - return _ALL_CADENCE_PASSES[p] + return ALL_CADENCE_PASSES[p] # A decorator that registers a pass. @@ -40,14 +40,14 @@ def register_cadence_pass( pass_attribute: CadencePassAttribute, ) -> Callable[[ExportPass], ExportPass]: def wrapper(cls: ExportPass) -> ExportPass: - _ALL_CADENCE_PASSES[cls] = pass_attribute + ALL_CADENCE_PASSES[cls] = pass_attribute return cls return wrapper def get_all_available_cadence_passes() -> Set[ExportPass]: - return set(_ALL_CADENCE_PASSES.keys()) + return set(ALL_CADENCE_PASSES.keys()) # Create a new filter to filter out relevant passes from all Jarvis passes. diff --git a/backends/cadence/aot/tests/test_pass_filter.py b/backends/cadence/aot/tests/test_pass_filter.py new file mode 100644 index 0000000000..7b49ef5c32 --- /dev/null +++ b/backends/cadence/aot/tests/test_pass_filter.py @@ -0,0 +1,160 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-unsafe + + +import unittest + +from copy import deepcopy + +from executorch.backends.cadence.aot import pass_utils +from executorch.backends.cadence.aot.pass_utils import ( + ALL_CADENCE_PASSES, + CadencePassAttribute, + create_cadence_pass_filter, + register_cadence_pass, +) + +from executorch.exir.pass_base import ExportPass + + +class TestBase(unittest.TestCase): + def setUp(self): + # Before running each test, create a copy of _all_passes to later restore it after test. + # This avoids messing up the original _all_passes when running tests. + self._all_passes_original = deepcopy(ALL_CADENCE_PASSES) + # Clear _all_passes to do a clean test. It'll be restored after each test in tearDown(). + pass_utils.ALL_CADENCE_PASSES.clear() + + def tearDown(self): + # Restore _all_passes to original state before test. + pass_utils.ALL_CADENCE_PASSES = self._all_passes_original + + def get_filtered_passes(self, filter_): + return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)} + + +# Test pass registration +class TestPassRegistration(TestBase): + def test_register_cadence_pass(self): + pass_attr_O0 = CadencePassAttribute(opt_level=0) + pass_attr_debug = CadencePassAttribute(opt_level=None, debug_pass=True) + pass_attr_O1_all_backends = CadencePassAttribute( + opt_level=1, + ) + + # Register 1st pass with opt_level=0 + @register_cadence_pass(pass_attr_O0) + class DummyPass_O0(ExportPass): + pass + + # Register 2nd pass with opt_level=1, all backends. + @register_cadence_pass(pass_attr_O1_all_backends) + class DummyPass_O1_All_Backends(ExportPass): + pass + + # Register 3rd pass with opt_level=None, debug=True + @register_cadence_pass(pass_attr_debug) + class DummyPass_Debug(ExportPass): + pass + + # Check if the three passes are indeed added into _all_passes + expected_all_passes = { + DummyPass_O0: pass_attr_O0, + DummyPass_Debug: pass_attr_debug, + DummyPass_O1_All_Backends: pass_attr_O1_all_backends, + } + self.assertEqual(pass_utils.ALL_CADENCE_PASSES, expected_all_passes) + + +# Test pass filtering +class TestPassFiltering(TestBase): + def test_filter_none(self): + pass_attr_O0 = CadencePassAttribute(opt_level=0) + pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True) + pass_attr_O1_all_backends = CadencePassAttribute( + opt_level=1, + ) + + @register_cadence_pass(pass_attr_O0) + class DummyPass_O0(ExportPass): + pass + + @register_cadence_pass(pass_attr_O1_debug) + class DummyPass_O1_Debug(ExportPass): + pass + + @register_cadence_pass(pass_attr_O1_all_backends) + class DummyPass_O1_All_Backends(ExportPass): + pass + + O1_filter = create_cadence_pass_filter(opt_level=1, debug=True) + O1_filter_passes = self.get_filtered_passes(O1_filter) + + # Assert that no passes are filtered out. + expected_passes = { + DummyPass_O0: pass_attr_O0, + DummyPass_O1_Debug: pass_attr_O1_debug, + DummyPass_O1_All_Backends: pass_attr_O1_all_backends, + } + self.assertEqual(O1_filter_passes, expected_passes) + + def test_filter_debug(self): + pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True) + pass_attr_O2 = CadencePassAttribute(opt_level=2) + + @register_cadence_pass(pass_attr_O1_debug) + class DummyPass_O1_Debug(ExportPass): + pass + + @register_cadence_pass(pass_attr_O2) + class DummyPass_O2(ExportPass): + pass + + debug_filter = create_cadence_pass_filter(opt_level=2, debug=False) + debug_filter_passes = self.get_filtered_passes(debug_filter) + + # Assert that debug passees are filtered out, since the filter explicitly + # chooses debug=False. + self.assertEqual(debug_filter_passes, {DummyPass_O2: pass_attr_O2}) + + def test_filter_all(self): + @register_cadence_pass(CadencePassAttribute(opt_level=1)) + class DummyPass_O1(ExportPass): + pass + + @register_cadence_pass(CadencePassAttribute(opt_level=2)) + class DummyPass_O2(ExportPass): + pass + + debug_filter = create_cadence_pass_filter(opt_level=0) + debug_filter_passes = self.get_filtered_passes(debug_filter) + + # Assert that all the passes are filtered out, since the filter only selects + # passes with opt_level <= 0 + self.assertEqual(debug_filter_passes, {}) + + def test_filter_opt_level_None(self): + pass_attr_O1 = CadencePassAttribute(opt_level=1) + pass_attr_O2_debug = CadencePassAttribute(opt_level=2, debug_pass=True) + + @register_cadence_pass(CadencePassAttribute(opt_level=None)) + class DummyPass_None(ExportPass): + pass + + @register_cadence_pass(pass_attr_O1) + class DummyPass_O1(ExportPass): + pass + + @register_cadence_pass(pass_attr_O2_debug) + class DummyPass_O2_Debug(ExportPass): + pass + + O2_filter = create_cadence_pass_filter(opt_level=2, debug=True) + filtered_passes = self.get_filtered_passes(O2_filter) + # Passes with opt_level=None should never be retained. + expected_passes = { + DummyPass_O1: pass_attr_O1, + DummyPass_O2_Debug: pass_attr_O2_debug, + } + self.assertEqual(filtered_passes, expected_passes)