From 4c43c86e968f679e82eee3e3be8ef48d3e55719a Mon Sep 17 00:00:00 2001 From: JP <46308822+zonglinpeng@users.noreply.github.com> Date: Fri, 6 Dec 2024 21:40:28 -0800 Subject: [PATCH 1/4] Revert "Reland cadence quantized_linear_per_tensor_out cpu 1eb924f^..fd33294" (#7225) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Revert "Reland cadence quantized_linear_per_tensor_out cpu 1eb924f^..fd33294 …" Caused merge conflict issue in diff train. This reverts commit 78b60df61bb2677b27935e4f83befbaea06861ee. --- backends/cadence/CMakeLists.txt | 4 +- backends/cadence/aot/functions.yaml | 50 -- .../reference/operators/CMakeLists.txt | 11 - .../reference/operators/im2row_out.cpp | 206 -------- .../cadence/reference/operators/operators.h | 57 --- .../operators/quantized_conv_out.cpp | 464 ++++-------------- .../operators/quantized_linear_out.cpp | 41 +- .../reference/operators/quantized_ops.h | 190 ------- .../cadence/reference/operators/targets.bzl | 3 - 9 files changed, 100 insertions(+), 926 deletions(-) delete mode 100644 backends/cadence/reference/operators/im2row_out.cpp delete mode 100644 backends/cadence/reference/operators/operators.h delete mode 100644 backends/cadence/reference/operators/quantized_ops.h diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index 6c71909c47..3cd880622c 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -23,6 +23,7 @@ include(${EXECUTORCH_ROOT}/build/Utils.cmake) # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) +set(TARGET_DIR reference) if(EXECUTORCH_CADENCE_CPU_RUNNER) include(${EXECUTORCH_ROOT}/build/Codegen.cmake) @@ -60,9 +61,6 @@ if(EXECUTORCH_CADENCE_CPU_RUNNER) ${_common_include_directories} ) - set(TARGET_DIR reference) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) - target_link_libraries( cadence_runner executorch diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index f1a5b6a50b..e7c16d0031 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -142,41 +142,6 @@ - arg_meta: null kernel_name: torch::executor::where_out -- op: transpose_copy.int_out - kernels: - - arg_meta: null - kernel_name: torch::executor::transpose_copy_int_out - -- op: eq.Scalar_out - kernels: - - arg_meta: null - kernel_name: torch::executor::eq_scalar_out - -- op: logical_not.out - kernels: - - arg_meta: null - kernel_name: torch::executor::logical_not_out - -- op: any.out - kernels: - - arg_meta: null - kernel_name: torch::executor::any_out - -- op: native_group_norm.out - kernels: - - arg_meta: null - kernel_name: torch::executor::native_group_norm_out - -- op: sum.IntList_out - kernels: - - arg_meta: null - kernel_name: torch::executor::sum_dim_out - -- op: select_copy.int_out - kernels: - - arg_meta: null - kernel_name: torch::executor::select_copy_int_out - # custom ops - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function @@ -218,18 +183,3 @@ kernels: - arg_meta: null kernel_name: impl::reference::quantized_matmul_out - -- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) - kernels: - - arg_meta: null - kernel_name: impl::reference::quantized_linear_per_tensor_out - -- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) - kernels: - - arg_meta: null - kernel_name: impl::reference::im2row_out - -- func: cadence::quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) - kernels: - - arg_meta: null - kernel_name: impl::reference::quantized_conv_per_tensor_out diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt index a2d51af2c0..c40d3ff66b 100644 --- a/backends/cadence/reference/operators/CMakeLists.txt +++ b/backends/cadence/reference/operators/CMakeLists.txt @@ -55,16 +55,6 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_expand_copy.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_transpose_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_eq.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_logical_not.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_any.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_native_group_norm.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sum.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_select_copy.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" - "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/select_copy_util.cpp" ) add_library(aten_ops_cadence ${_aten_ops__srcs}) target_link_libraries(aten_ops_cadence PUBLIC executorch) @@ -88,7 +78,6 @@ add_library( "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp" "quantized_matmul_out.cpp" - "im2row_out.cpp" ) target_include_directories( custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} diff --git a/backends/cadence/reference/operators/im2row_out.cpp b/backends/cadence/reference/operators/im2row_out.cpp deleted file mode 100644 index dd539b6f9b..0000000000 --- a/backends/cadence/reference/operators/im2row_out.cpp +++ /dev/null @@ -1,206 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include - -#include - -namespace impl { -namespace reference { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -template -__attribute__((always_inline)) void im2row_( - const T* __restrict__ data_im, - const int32_t in_zero_point, - /* input parameters*/ - const int32_t channels, - const int32_t height, - const int32_t width, - /* output parameters */ - const int32_t out_height, - const int32_t out_width, - /* convolution parameters */ - const int32_t kernel_h, - const int32_t kernel_w, - const int32_t pad_h, - const int32_t pad_w, - const int32_t stride_h, - const int32_t stride_w, - const int32_t dilation_h, - const int32_t dilation_w, - T* __restrict__ data_col, - bool channels_last) { - // Consider convolving the input image of dimensions channels * height * width - // (or height * width * channels for NHWC layout) with a filter of dimensions - // channels * kernels_h * kernels_w. Assume that this convolution will produce - // an output of dimensinos out_height x out_width. For each point the output, - // im2row takes the data from the input that is used in the computation of - // that output point, and flattens it into a vector of size channels_col = - // channels * kernel_h * kernel_w. The output of im2row will therefore be a 2D - // array of size (out_height * out_width) x channels_col - const int32_t channels_col = channels * kernel_h * kernel_w; - - // If the layout is NHWC, we can copy 'channels' worth of contiguous data - // points when performing im2row. - if (channels_last) { - // Iterate over the output domain - for (int _h = 0; _h < out_height; ++_h) { - for (int _w = 0; _w < out_width; ++_w) { - int32_t i_col = _h * out_width + _w; - // Each point in the output domain is the result of applying a filter of - // size kernel_h x kernel_w x channels on the input. But since channels - // is contiguous, we will not explicitly have a loop for it. - for (int _kh = 0; _kh < kernel_h; ++_kh) { - int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; - for (int _kw = 0; _kw < kernel_w; ++_kw) { - int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; - - // h_im and w_im are the actual height and width coordinates of the - // input tensor from where we need to copy 'channels' points. - const T* __restrict__ slice_im = - data_im + (h_im * width + w_im) * channels; - T* __restrict__ slice_col = data_col + i_col * channels_col + - (_kh * kernel_w + _kw) * channels; - // If the coordinates were within the input domain, we copy - // 'channels' contiguous values. Otherwise we will fill the output - // with 0's. - if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { - std::memcpy(slice_col, slice_im, channels * sizeof(T)); - } else { - std::fill_n(slice_col, channels, T(in_zero_point)); - } - } - } - } - } - } else { - // Iterate over the output domain - for (int _h = 0; _h < out_height; ++_h) { - for (int _w = 0; _w < out_width; ++_w) { - int32_t i_col = _h * out_width + _w; - - // Each point in the output domain is the result of applying a filter - // of size chanenls * kernel_h x kernel_w on the input - for (int _c = 0; _c < channels; ++_c) { - for (int _kh = 0; _kh < kernel_h; ++_kh) { - for (int _kw = 0; _kw < kernel_w; ++_kw) { - // c_col is the linearized access in the channels_col vector. - int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; - // h_im and w_im are the actual height and width coordinates of - // the input tensor that we need to copy to the output. - int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; - int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; - // If the current data access is within the input tensor, copy the - // value - data_col[i_col * channels_col + c_col] = - (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) - ? data_im[(_c * height + h_im) * width + w_im] - : static_cast(in_zero_point); - } - } - } - } - } - } -} - -void im2row_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef dilation, - IntArrayRef padding, - IntArrayRef stride, - const Tensor& in_zero_point, - bool channel_last, - Tensor& out) { - // Compute the input tensor's dims - bool unit_height = input.dim() == 3; - const int32_t batch_size = input.size(0); - const int32_t in_c = - channel_last ? input.size(3 - unit_height) : input.size(1); - const int32_t in_h = - unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); - const int32_t in_w = - channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); - - // Get the kernel parameters - int32_t kernel_h = kernel_size[0]; - int32_t kernel_w = kernel_size[1]; - int32_t dilation_h = dilation[0]; - int32_t dilation_w = dilation[1]; - int32_t pad_h = padding[0]; - int32_t pad_w = padding[1]; - int32_t stride_h = stride[0]; - int32_t stride_w = stride[1]; - - // If we were to apply a convolution on the input tensor, compute the output - // height and width. - int32_t out_h = - (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; - int32_t out_w = - (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; - - ET_DCHECK_MSG( - (out_h * out_w) == out.size(1), "dimension mismatch for output"); - ET_DCHECK_MSG( - (kernel_h * kernel_w * in_c) == out.size(2), - "dimension mismatch for output"); - - // Check if the input is per-tensor quantized or per-channel quantized. The - // zero point for each batch could differ for per-channel quantized input. - bool per_tensor_quantized = in_zero_point.numel() == 1; - -#define typed_im2row(dtype, ctype) \ - case ScalarType::dtype: { \ - const ctype* __restrict__ in_data = input.const_data_ptr(); \ - ctype* __restrict__ out_data = out.mutable_data_ptr(); \ - const int32_t* __restrict__ zero_point = \ - in_zero_point.const_data_ptr(); \ - int32_t in_plane = in_c * in_h * in_w; \ - int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ - for (size_t n = 0; n < batch_size; ++n) { \ - im2row_( \ - &in_data[n * in_plane], \ - per_tensor_quantized ? zero_point[0] : zero_point[n], \ - in_c, \ - in_h, \ - in_w, \ - out_h, \ - out_w, \ - kernel_h, \ - kernel_w, \ - pad_h, \ - pad_w, \ - stride_h, \ - stride_w, \ - dilation_h, \ - dilation_w, \ - &out_data[n * out_plane], \ - channel_last); \ - } \ - break; \ - } - - ScalarType dtype = input.scalar_type(); - switch (dtype) { - typed_im2row(Float, float); - typed_im2row(Byte, uint8_t); - typed_im2row(Char, int8_t); - default: - ET_DCHECK_MSG( - false, - "im2row not implemented for dtype %s", - torch::executor::toString(dtype)); - } -#undef typed_im2row -} - -} // namespace native -} // namespace reference -} // namespace impl diff --git a/backends/cadence/reference/operators/operators.h b/backends/cadence/reference/operators/operators.h deleted file mode 100644 index 0ff4639255..0000000000 --- a/backends/cadence/reference/operators/operators.h +++ /dev/null @@ -1,57 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include - -namespace cadence { -namespace impl { -namespace cpu { -namespace native { -namespace { -using ::executorch::runtime::getLeadingDims; - -#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ - _(uint8_t, Byte) \ - _(int8_t, Char) - -inline __attribute__((always_inline)) void linear_( - const ::executorch::aten::Tensor& input, - const ::executorch::aten::Tensor& weight, - const ::executorch::aten::optional<::executorch::aten::Tensor>& bias, - ::executorch::aten::Tensor& output) { - const float* __restrict__ input_data = input.const_data_ptr(); - const float* __restrict__ weight_data = weight.const_data_ptr(); - const float* __restrict__ bias_data = bias.value().const_data_ptr(); - float* __restrict__ output_data = output.mutable_data_ptr(); - - // input comes in shape [batch_size, in_dim] - // weight comes in shape [out_dim, in_dim] - // output comes in empty with shape [batch_size, out_dim] - // Perform matrix multiply (M x N) x (N x P) => M x P - int64_t M = weight.size(0); // = out_dim - int64_t N = weight.size(1); // = in_dim - - // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the - // leading dimensions is d0 * d1 * ... * d_{N-2} - int64_t leading_dims = getLeadingDims(input, input.dim() - 1); - - for (int i = 0; i < leading_dims; ++i) { - for (int j = 0; j < M; ++j) { - float sum = bias_data[j]; - for (int k = 0; k < N; ++k) { - sum += input_data[i * N + k] * weight_data[j * N + k]; - } - output_data[i * M + j] = sum; - } - } -} - -} // namespace -} // namespace native -} // namespace cpu -} // namespace impl -} // namespace cadence diff --git a/backends/cadence/reference/operators/quantized_conv_out.cpp b/backends/cadence/reference/operators/quantized_conv_out.cpp index 5a7af85809..de19f3ef43 100644 --- a/backends/cadence/reference/operators/quantized_conv_out.cpp +++ b/backends/cadence/reference/operators/quantized_conv_out.cpp @@ -1,16 +1,21 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include -#include + +#include namespace impl { namespace reference { namespace native { -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; // This implements a generic 2d conv kernel that operates on raw pointers. // The version handles both quantized and fp32 convolutions. @@ -18,12 +23,7 @@ using ::executorch::runtime::KernelRuntimeContext; // The weight is of shape [oc x wc x wh x ww], where wc == c // The output is of shape [n x oc x oh x ow] // The bias is of shape [oc] -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> +template __attribute__((noinline)) void conv2d_nchw_core_generic( // All the arrays const IT* __restrict__ p_in, @@ -56,10 +56,11 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( // input zero point IT in_zero_point = 0, // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, + const int32_t* __restrict__ weight_zero_point = nullptr, + const float* __restrict__ bias_scale = nullptr, float out_scale = 1, - OT out_zero_point = 0) { + OT out_zero_point = 0, + bool per_tensor_quantized = true) { float inv_out_scale = 1. / out_scale; bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; @@ -105,7 +106,7 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( int woff = _wh * ww + _ww; float lhs = in_plane[ioff] - in_zero_point; float rhs = weight_plane[woff] - - (quantized ? weight_zero_point : 0); + (quantized ? weight_zero_point[0] : 0); acc += lhs * rhs; } } @@ -125,7 +126,7 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( int woff = _wh * ww + _ww; float lhs = in_plane[ioff] - in_zero_point; float rhs = weight_plane[woff] - - (quantized ? weight_zero_point : 0); + (quantized ? weight_zero_point[0] : 0); acc += lhs * rhs; } } @@ -133,10 +134,11 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( } } if (quantized) { - float val = bias_scale * acc; + float val = + (per_tensor_quantized ? bias_scale[0] : bias_scale[_oc]) * + acc; out_plane[_oh * ow + _ow] = - ::impl::reference::kernels::quantize( - val, inv_out_scale, out_zero_point); + kernels::quantize(val, inv_out_scale, out_zero_point); } else { out_plane[_oh * ow + _ow] = acc; } @@ -147,149 +149,27 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( } } -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> -__attribute__((noinline)) void conv2d_nhwc_core_generic( - // All the arrays - const IT* __restrict__ p_in, - const WT* __restrict__ p_weight, - const BT* __restrict__ p_bias, - OT* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t h, - int32_t w, - int32_t c, - int32_t oc, - int32_t wh, - int32_t ww, - int32_t wc, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Optional args that are only relevant for quantized convolution - // input zero point - IT in_zero_point = 0, - // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - OT out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const IT* in_batch = p_in + _n * h * w * c; - OT* out_batch = p_out + _n * oh * ow * oc; - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - OT* out_line = out_batch + (_oh * ow + _ow) * oc; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - const WT* weight_batch = p_weight + _oc * wh * ww * wc; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of - // size h x w x icpg, with a stencil of size wh x ww x icpg, to - // compute an output channel of size oh x ow x 1. - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to - // the output channel being computed) with the corresponding - // weight channel. If the padding is 0, and dilation is 1, then - // we can remove the unnecessary checks, and simplify the code - // so that it can be vectorized by Tensilica compiler.x`` - if (zero_pad_unit_dilation) { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - const IT* in_line = - in_batch + (_h + _wh) * w * c + (_w + _ww) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } else { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - if (((_h + d0 * _wh - p0) >= 0) && - ((_h + d0 * _wh - p0) < h) && - ((_w + d1 * _ww - p1) >= 0) && - ((_w + d1 * _ww - p1 < w))) { - const IT* in_line = in_batch + - (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } - } - if (quantized) { - float val = bias_scale * acc; - out_line[_oc] = ::impl::reference::kernels::quantize( - val, inv_out_scale, out_zero_point); - } else { - out_line[_oc] = acc; - } - } - } - } - } - } -} - // The quantized convolution kernel. in_scale and weight_scale are implicit in // bias_scale, since it is a product of the two. The kernel will branch to // quantized::conv1d or quantized::conv2d based on the dimensionality of // activation tensor. -void quantized_conv_nchw( +void quantized_conv_out( + KernelRuntimeContext& ctx, const Tensor& input, const Tensor& weight, const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + bool channel_last, Tensor& out) { bool conv1d = input.dim() == 3; // input = [n, c, h, w] @@ -306,224 +186,76 @@ void quantized_conv_nchw( const int oh = conv1d ? 1 : out.size(2); const int ow = conv1d ? out.size(2) : out.size(3); -#define typed_quantized_conv2d_nchw(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nchw_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - c, \ - h, \ - w, \ - oc, \ - wc, \ - wh, \ - ww, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nchw -} - -void quantized_conv_nhwc( - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, h, w, c] - const int n = input.size(0); - const int h = conv1d ? 1 : input.size(1); - const int w = conv1d ? input.size(1) : input.size(2); - const int c = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wh, ww, wc] - const int oc = weight.size(0); - const int wh = conv1d ? 1 : weight.size(1); - const int ww = conv1d ? weight.size(1) : weight.size(2); - const int wc = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oh, ow, oc] - const int oh = conv1d ? 1 : out.size(1); - const int ow = conv1d ? out.size(1) : out.size(2); - -#define typed_quantized_conv2d_nhwc(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nhwc_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - h, \ - w, \ - c, \ - oc, \ - wh, \ - ww, \ - wc, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nhwc -} + // Bool flag to check if weight tensor is quantized per-tensor or + // per-channel + bool per_tensor_quantized = bias_scale.numel() == 1; -void quantized_conv_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED const Tensor& out_multiplier, - __ET_UNUSED const Tensor& out_shift, - bool channel_last, - Tensor& out) { - const float bias_scale_float = bias_scale.const_data_ptr()[0]; - const int32_t weight_zero_point_int = - weight_zero_point.const_data_ptr()[0]; - if (channel_last) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, + if (out.scalar_type() == exec_aten::ScalarType::Byte) { + conv2d_nchw_core_generic( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], groups, in_zero_point, - weight_zero_point_int, - bias_scale_float, + weight_zero_point.const_data_ptr(), + bias_scale.const_data_ptr(), output_scale, - output_zero_point, - out); - } else { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, + (uint8_t)output_zero_point, + per_tensor_quantized); + } else if (out.scalar_type() == exec_aten::ScalarType::Char) { + conv2d_nchw_core_generic( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], groups, in_zero_point, - weight_zero_point_int, - bias_scale_float, + weight_zero_point.const_data_ptr(), + bias_scale.const_data_ptr(), output_scale, - output_zero_point, - out); - } -} - -void quantized_conv_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - bool channel_last, - Tensor& out) { - if (channel_last) { - quantized_conv_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); + (int8_t)output_zero_point, + per_tensor_quantized); } else { - quantized_conv_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast(input.scalar_type())); } } -} // namespace native -} // namespace reference -} // namespace impl +}; // namespace native +}; // namespace reference +}; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_linear_out.cpp b/backends/cadence/reference/operators/quantized_linear_out.cpp index 4f7ca9cc3c..7bb1bf6fb4 100644 --- a/backends/cadence/reference/operators/quantized_linear_out.cpp +++ b/backends/cadence/reference/operators/quantized_linear_out.cpp @@ -6,8 +6,7 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include +#include #include namespace impl { @@ -86,7 +85,6 @@ void quantized_linear_out( int64_t out_zero_point, __ET_UNUSED const executorch::aten::optional& offset, Tensor& out) { - // TODO: refactor to use switch case as quantized_linear_per_tensor_out if (out.scalar_type() == executorch::aten::ScalarType::Byte) { _typed_quantized_linear( src, @@ -117,43 +115,6 @@ void quantized_linear_out( } } -void quantized_linear_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - __ET_UNUSED const executorch::aten::optional& offset, - Tensor& out) { -#define typed_quantized_linear_per_tensor(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - src, \ - weight, \ - bias, \ - src_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - executorch::aten::ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); - } -#undef typed_quantized_linear_per_tensor -} - }; // namespace native }; // namespace reference }; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_ops.h b/backends/cadence/reference/operators/quantized_ops.h deleted file mode 100644 index 66545c8e58..0000000000 --- a/backends/cadence/reference/operators/quantized_ops.h +++ /dev/null @@ -1,190 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include - -template -inline __attribute__((always_inline)) void quantized_linear_per_tensor_( - const ::executorch::aten::Tensor& src, - const ::executorch::aten::Tensor& weight, - const ::executorch::aten::Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - ::executorch::aten::Tensor& out) { - // input comes in shape [leading_dims, in_dim] - // weight comes in shape [out_dim, in_dim] - // output comes in empty with shape [leading_dims, out_dim] - // Perform matrix multiply (M x N) x (N x P)' => M x P - const int64_t leading_dims = - executorch::runtime::getLeadingDims(src, src.dim() - 1); - const int64_t out_dim = weight.size(0); // = out_dim - const int64_t in_dim = weight.size(1); // = in_dim - - const T* __restrict__ in_data = src.const_data_ptr(); - const T* __restrict__ weight_data = weight.const_data_ptr(); - const int32_t* __restrict__ bias_data = bias.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - - // Compute the requant_scale from out_multiplier and out_shift - const float requant_scale = - -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); - - for (size_t i = 0; i < leading_dims; ++i) { - for (size_t j = 0; j < out_dim; ++j) { - int32_t sum = bias_data[j]; - for (size_t k = 0; k < in_dim; ++k) { - int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; - int32_t w = - (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; - sum += x * w; - } - out_data[i * out_dim + j] = ::impl::reference::kernels::quantize( - sum, requant_scale, out_zero_point); - } - } -} - -template -inline __attribute__((always_inline)) void quantized_linear_per_tensor_( - const ::executorch::aten::Tensor& src, - const ::executorch::aten::Tensor& weight, - const ::executorch::aten::Tensor& bias, - int64_t src_zero_point, - const ::executorch::aten::Tensor& weight_zero_point_t, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - ::executorch::aten::Tensor& out) { - // Get the zero_point of weight. - int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; - quantized_linear_per_tensor_( - src, - weight, - bias, - src_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - out); -} - -template -inline __attribute__((always_inline)) void quantized_linear_per_channel_( - const ::executorch::aten::Tensor& src, - const ::executorch::aten::Tensor& weight, - const ::executorch::aten::Tensor& bias, - int64_t src_zero_point, - int64_t weight_zero_point, - const ::executorch::aten::Tensor& out_multiplier, - const ::executorch::aten::Tensor& out_shift, - int64_t out_zero_point, - ::executorch::aten::Tensor& out) { - // input comes in shape [leading_dims, in_dim] - // weight comes in shape [out_dim, in_dim] - // output comes in empty with shape [leading_dims, out_dim] - // Perform matrix multiply (M x N) x (N x P)' => M x P - int64_t leading_dims = - executorch::runtime::getLeadingDims(src, src.dim() - 1); - const int64_t out_dim = weight.size(0); // = out_dim - const int64_t in_dim = weight.size(1); // = in_dim - - const T* __restrict__ in_data = src.const_data_ptr(); - const T* __restrict__ weight_data = weight.const_data_ptr(); - const int32_t* __restrict__ bias_data = bias.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - const int32_t* __restrict__ out_multiplier_data = - out_multiplier.const_data_ptr(); - const int32_t* __restrict__ out_shift_data = - out_shift.const_data_ptr(); - - for (size_t i = 0; i < leading_dims; ++i) { - for (size_t j = 0; j < out_dim; ++j) { - int32_t sum = bias_data[j]; - for (size_t k = 0; k < in_dim; ++k) { - int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; - int32_t w = - (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; - sum += x * w; - } - // Compute the out_scale from out_multiplier and out_shift - const float out_scale = - -out_multiplier_data[j] * 1.0 / (1 << 31) * pow(2, out_shift_data[j]); - out_data[i * out_dim + j] = ::impl::reference::kernels::quantize( - sum, out_scale, out_zero_point); - } - } -} - -template -inline __attribute__((always_inline)) void quantized_linear_( - const ::executorch::aten::Tensor& src, - const ::executorch::aten::Tensor& weight, - const ::executorch::aten::Tensor& bias, - int64_t src_zero_point, - int64_t weight_zero_point, - const ::executorch::aten::Tensor& out_multiplier, - const ::executorch::aten::Tensor& out_shift, - int64_t out_zero_point, - ::executorch::aten::Tensor& out) { - if (out_multiplier.numel() == 1) { - // Use per-tensor quantization kernel. - const int32_t* __restrict__ out_multiplier_data = - out_multiplier.const_data_ptr(); - const int32_t* __restrict__ out_shift_data = - out_shift.const_data_ptr(); - quantized_linear_per_tensor_( - src, - weight, - bias, - src_zero_point, - weight_zero_point, - out_multiplier_data[0], - out_shift_data[0], - out_zero_point, - out); - return; - } - - // Use per-channel quantization kernel. - quantized_linear_per_channel_( - src, - weight, - bias, - src_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - out); -} - -template -inline __attribute__((always_inline)) void quantized_linear_( - const ::executorch::aten::Tensor& src, - const ::executorch::aten::Tensor& weight, - const ::executorch::aten::Tensor& bias, - int64_t src_zero_point, - const ::executorch::aten::Tensor& weight_zero_point_t, - const ::executorch::aten::Tensor& out_multiplier, - const ::executorch::aten::Tensor& out_shift, - int64_t out_zero_point, - ::executorch::aten::Tensor& out) { - // Get the zero_point of weight. - int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; - quantized_linear_( - src, - weight, - bias, - src_zero_point, - weight_zero_point, - out_multiplier, - out_shift, - out_zero_point, - out); -} diff --git a/backends/cadence/reference/operators/targets.bzl b/backends/cadence/reference/operators/targets.bzl index 488aeebb82..347d476239 100644 --- a/backends/cadence/reference/operators/targets.bzl +++ b/backends/cadence/reference/operators/targets.bzl @@ -7,9 +7,6 @@ def define_common_targets(): srcs = glob([ "*.cpp", ]), - exported_headers =glob([ - "*.h", - ]), platforms = CXX, deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util", From ffc1273a09c38d336d2a192abe3f6da02bdb1744 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Sun, 8 Dec 2024 22:13:17 -0800 Subject: [PATCH 2/4] [ET-VK] Store unique ptr to Tensor in Value instead of inlined tensor object, to reduce Value struct size from 448 to 80 bytes. Pull Request resolved: https://github.com/pytorch/executorch/pull/7145 This diff aims to reduce the size of the Value struct in the Executorch Vulkan runtime by storing a unique pointer to the Tensor object instead of an inlined tensor object. This change reduces the size of the Value struct from 448 bytes to 80 bytes, which can improve performance and reduce memory usage. ghstack-source-id: 256911524 @exported-using-ghexport Differential Revision: [D66655991](https://our.internmc.facebook.com/intern/diff/D66655991/) --------- Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> --- .../vulkan/runtime/graph/containers/Value.h | 55 ++++++++++++++----- .../vulkan/test/vulkan_compute_api_test.cpp | 4 +- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index 8773f0c0b0..83669c85b1 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -58,7 +58,7 @@ struct Value final { bool as_bool; } u; - api::vTensor as_tensor; + std::unique_ptr as_tensor; api::StagingBuffer as_staging; TensorRef as_tensorref; @@ -106,15 +106,18 @@ struct Value final { rhs.payload.member_name.~dtor_name(); \ break; +#define CASE_MOVE_UNIQUE_PTR_TYPE(type_tag, member_name) \ + case type_tag: \ + payload.member_name = std::move(rhs.payload.member_name); \ + break; + Value(Value&& rhs) noexcept : tag(rhs.tag) { switch (tag) { // Scalar types CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::INT, as_int); CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double); CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool); - // Tensor and tensor adjacent types - CASE_MOVE_MOVEABLE_TYPE( - TypeTag::TENSOR, api::vTensor, as_tensor, vTensor); + // Tensor adjacent types CASE_MOVE_MOVEABLE_TYPE( TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer); CASE_MOVE_MOVEABLE_TYPE( @@ -132,6 +135,8 @@ struct Value final { CASE_MOVE_MOVEABLE_TYPE( TypeTag::STRING, std::string, as_string, basic_string); CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt); + // Tensor type + CASE_MOVE_UNIQUE_PTR_TYPE(TypeTag::TENSOR, as_tensor); case TypeTag::NONE: clearToNone(); @@ -142,6 +147,7 @@ struct Value final { #undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE #undef CASE_MOVE_MOVEABLE_TYPE +#undef CASE_MOVE_UNIQUE_PTR_TYPE // // Accessors @@ -157,9 +163,6 @@ struct Value final { ~Value() { switch (tag) { - case TypeTag::TENSOR: - payload.as_tensor.~vTensor(); - break; case TypeTag::STAGING: payload.as_staging.~StagingBuffer(); break; @@ -184,6 +187,9 @@ struct Value final { case TypeTag::SYMINT: payload.as_symint.~SymInt(); break; + case TypeTag::TENSOR: + payload.as_tensor.reset(); + break; // Manually list out the types so that if a type here is added later and // not handled the compiler can catch it. case TypeTag::NONE: @@ -252,12 +258,6 @@ struct Value final { return payload.member_name; \ } - SUPPORT_TRIVIALLY_MOVEABLE_TYPE( - api::vTensor, - Tensor, - TypeTag::TENSOR, - as_tensor); - SUPPORT_TRIVIALLY_MOVEABLE_TYPE( api::StagingBuffer, Staging, @@ -302,9 +302,36 @@ struct Value final { SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint); -#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE #undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE +#define SUPPORT_UNIQUE_PTR_TYPE(type, type_name, type_tag, member_name) \ + explicit Value(type t) : tag(type_tag) { \ + payload.member_name = std::make_unique(std::move(t)); \ + } \ + inline bool is##type_name() const { \ + return tag == type_tag; \ + } \ + inline type& to##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return *payload.member_name; \ + } \ + inline const type& toConst##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return *payload.member_name; \ + } + + SUPPORT_UNIQUE_PTR_TYPE(api::vTensor, Tensor, TypeTag::TENSOR, as_tensor); + +#undef SUPPORT_UNIQUE_PTR_TYPE + private: Payload payload; TypeTag tag; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index ee49f95ee2..6e491bed22 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1087,8 +1087,8 @@ TEST_F(VulkanComputeAPITest, print_object_sizes) { // Current known size on 64 bit system: 1040 B EXPECT_TRUE(sizeof(vTensor) < 1200); - // Current known size on 64 bit system: 1056 B - EXPECT_TRUE(sizeof(Value) < 1200); + // Current known size on 64 bit system: 120 B + EXPECT_TRUE(sizeof(Value) < 128); // Current known size on 64 bit system: 120 B EXPECT_TRUE(sizeof(StagingBuffer) < 500); // Current known size on 64 bit system: 384 B From 06e85a8d128a999d3b125eaa77d278c4a64319bb Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Sun, 8 Dec 2024 22:14:39 -0800 Subject: [PATCH 3/4] [executorch][schema] Add 'EXTERNAL' to DataLocation in schema Pull Request resolved: https://github.com/pytorch/executorch/pull/7191 To indicate if a tensor is external to the PTE file or not. Currently, we can also use the existence of 'fqn' to determine if a tensor is external or not. I think it's better to have a specific location field as fqn may be required for cases besides external tensor storage. ghstack-source-id: 257035024 @exported-using-ghexport Differential Revision: [D66523171](https://our.internmc.facebook.com/intern/diff/D66523171/) Co-authored-by: lucylq --- exir/passes/replace_view_copy_with_view_pass.py | 1 + exir/schema.py | 8 +++++++- exir/tensor.py | 5 ++++- schema/program.fbs | 16 +++++++++++++++- 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index 378b933211..b19cfbed95 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -109,6 +109,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: "mem_obj_id", "mem_offset", "dtype", # property + "extra_tensor_info", # property ] # Make sure _self_fields and _base_fields are disjoint diff --git a/exir/schema.py b/exir/schema.py index 9ef294abb6..17810ff693 100644 --- a/exir/schema.py +++ b/exir/schema.py @@ -43,14 +43,20 @@ class TensorShapeDynamism(IntEnum): DYNAMIC_UNBOUND = 2 +class TensorDataLocation(IntEnum): + SEGMENT = 0 + EXTERNAL = 1 + + @dataclass class ExtraTensorInfo: """ Check program.fbs for explanations of this enum. """ - mutable_data_segments_idx: Optional[int] = None + mutable_data_segments_idx: int = 0 fully_qualified_name: Optional[str] = None + location: TensorDataLocation = TensorDataLocation.SEGMENT @dataclass diff --git a/exir/tensor.py b/exir/tensor.py index 0c5218bb59..a26f15a238 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -18,7 +18,7 @@ import executorch.exir.schema as schema import torch from executorch.exir.error import internal_assert -from executorch.exir.schema import ScalarType, TensorShapeDynamism +from executorch.exir.schema import ExtraTensorInfo, ScalarType, TensorShapeDynamism from executorch.exir.sym_util import eval_shape @@ -132,6 +132,7 @@ def __init__( is_sparse: bool = False, const: bool = False, requires_grad: bool = False, + extra_tensor_info: Optional[ExtraTensorInfo] = None, ) -> None: self.scalar_type = dtype self.const = const @@ -146,6 +147,7 @@ def __init__( self.is_sparse = is_sparse self.init_mem_planning_fields() self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape) + self.extra_tensor_info = extra_tensor_info @property def allocated_memory(self) -> int: @@ -346,6 +348,7 @@ def to_list( allocation_info=allocation_info, layout=layout_enum(spec.layout), shape_dynamism=spec.shape_dynamism, + extra_tensor_info=spec.extra_tensor_info, ) return flatbuffer_tensor diff --git a/schema/program.fbs b/schema/program.fbs index 064df063cf..7ab2175f8a 100644 --- a/schema/program.fbs +++ b/schema/program.fbs @@ -53,6 +53,13 @@ enum TensorShapeDynamism : byte { DYNAMIC_UNBOUND = 2, } +// Indicates where a tensor is stored. +enum TensorDataLocation : byte { + // Stored in a segment of the PTE file. + SEGMENT = 0, + // Stored outside of the PTE file. + EXTERNAL = 1, +} // Table to put additional information about tensors in that is not applicable // to the vast majority of tensors in the vast majority of programs. @@ -60,11 +67,18 @@ table ExtraTensorInfo { // [Optional] Specifies the SubsegmentOffsets in // program.mutable_data_segments that specifies where the data is located in. // If not present and the data is located in a segment, then the data is in - // the first index. + // index zero. mutable_data_segments_idx: uint64; // [Optional] The unique name of the tensor. e.g. 'mod.linear.weight' fully_qualified_name: string; + + // [Optional] Specifies where the tensor's data is stored. + // - SEGMENT (default): Data is stored in a segment. + // - EXTERNAL: Data is stored outside of the PTE file. fully_qualified_name + // must be non-empty, and is used as a key to find the tensor's external + // data. Tensor.data_buffer_idx is ignored. + location: TensorDataLocation; } table Tensor { From b9db0a3308502d4c60c6473e476479a5e14eb7be Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Sun, 8 Dec 2024 22:17:24 -0800 Subject: [PATCH 4/4] [executorch][emit] Refactor _tensor_spec_to_evalue ^ adding more logic to _tensor_spec_to_evalue in the next diff; simplifying it now. Otherwise, linter error on complexity. Differential Revision: [D66847875](https://our.internmc.facebook.com/intern/diff/D66847875/) ghstack-source-id: 256981105 Pull Request resolved: https://github.com/pytorch/executorch/pull/7233 Co-authored-by: lucylq --- exir/emit/_emitter.py | 103 +++++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 42 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 381bab618c..2d6c066cce 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -48,6 +48,7 @@ from executorch.exir.passes.executorch_prim_ops_registry import is_sym_op from executorch.exir.print_program import _stacktrace_to_framelist, inspect_node from executorch.exir.schema import ( + AllocationDetails, BackendDelegate, BackendDelegateDataReference, BackendDelegateInlineData, @@ -328,6 +329,59 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue: ExportErrorType.NOT_SUPPORTED, f"Unknown list type: {val_type}" ) + def _get_allocation_info(self, spec: TensorSpec) -> AllocationDetails: + """Returns the allocation info for a given TensorSpec.""" + self._internal_assert_emitter( + isinstance(spec.mem_id, int) and spec.mem_id >= 0, + self.node, + f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}", + ) + + self._internal_assert_emitter( + isinstance(spec.mem_offset, int) and spec.mem_offset >= 0, + self.node, + f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}", + ) + try: + allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset) + except AddressSpaceOverflowException as e: + raise InternalError( + self._emit_node_specific_error( + self.node, + ( + f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, " + f"such as ConstraintBasedSymShapeEvalPass, this may be the cause of an " + f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) " + "during torch.export()." + ), + ) + ) + return allocation_info + + def _save_new_const_tensor( + self, + spec: TensorSpec, + buffer_data: bytes, + hashed: str, + allocation_info: Optional[AllocationDetails], + ) -> int: + """Saves a new constant tensor to the constant buffer and returns the buffer idx""" + + self.program_state.allocated_specs.append(spec) + # +1 because the first buffer location is reserved. + + # Update buffer_idx to point to the end of the list where we are adding the new buffer. + buffer = Buffer(storage=buffer_data) + if allocation_info: + buffer_idx = len(self.program_state.mutable_buffer) + self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx + self.program_state.mutable_buffer.append(buffer) + else: + buffer_idx = len(self.program_state.constant_buffer) + self.program_state.cached_spec_hash_values[hashed] = buffer_idx + self.program_state.constant_buffer.append(buffer) + return buffer_idx + def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: """Constructs an EValue from the given TensorSpec.""" @@ -339,35 +393,12 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: # default algos to set offsets, so need to check both. if spec.mem_id is not None and spec.mem_offset is not None: # Tensor is an activation. - self._internal_assert_emitter( - isinstance(spec.mem_id, int) and spec.mem_id >= 0, - self.node, - f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}", - ) - - self._internal_assert_emitter( - isinstance(spec.mem_offset, int) and spec.mem_offset >= 0, - self.node, - f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}", - ) - try: - allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset) - except AddressSpaceOverflowException as e: - raise InternalError( - self._emit_node_specific_error( - self.node, - ( - f"{e}\nHint: If you are using a memory pass based on dynamic shape bounds, " - f"such as ConstraintBasedSymShapeEvalPass, this may be the cause of an " - f"unbacked SymInt with its upper bound lazily set to 2^64-1 (uint64 max) " - "during torch.export()." - ), - ) - ) + allocation_info = self._get_allocation_info(spec) + # Tensor is either a constant tensor, or a mutable tensor with an initial state. if spec.const: # Tensor with a blob we need to serialize. May not actually be constant at runtime - # if it's a weight with an associated gradient + # if it's a weight with an associated gradient. spec_array_type = ( ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes() ) @@ -392,23 +423,11 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: else: buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1) - # Haven't seen this constant before + # Haven't seen this constant before. if buffer_idx == -1: - # Update buffer_idx to point to the end of the list where we are adding the new buffer. - buffer = Buffer(storage=buffer_data) - self.program_state.allocated_specs.append(spec) - # +1 because the first buffer location is reserved - - if allocation_info: - buffer_idx = len(self.program_state.mutable_buffer) - self.program_state.cached_spec_mutable_hash_values[hashed] = ( - buffer_idx - ) - self.program_state.mutable_buffer.append(buffer) - else: - buffer_idx = len(self.program_state.constant_buffer) - self.program_state.cached_spec_hash_values[hashed] = buffer_idx - self.program_state.constant_buffer.append(buffer) + buffer_idx = self._save_new_const_tensor( + spec, buffer_data, hashed, allocation_info + ) if spec.const and spec.nbytes() != len(buffer_data): raise InternalError(