diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp
index 847f764b0e..f07592fbfb 100644
--- a/kernels/quantized/cpu/op_dequantize.cpp
+++ b/kernels/quantized/cpu/op_dequantize.cpp
@@ -11,6 +11,9 @@
 #include
 #include
 #include
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
 
 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;
 
 namespace {
 
@@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }
 
+/**
+ * Useful to apply a function `fn` along a given dimension `dim` of the tensor
+ * `in`. `fn` should have the following signature:
+ * void fn(const size_t inner_size, const size_t outer_idx,
+ *     const size_t unpacked_dim_idx)
+ * where `inner_size` is the number of contiguous elements trailing `dim`,
+ * `outer_idx` is the index over the dimensions preceding `dim`, and
+ * `unpacked_dim_idx` is the index along `dim` itself.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  // Each iteration dequantizes 16 int8 values: widen, subtract the
+  // zero_point, convert to float, and multiply by the scale.
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  // Scalar path: non-NEON builds and the remaining tail elements.
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       &scales,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = get_scale(scales, unpacked_dim_idx);
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace
 
 /**
@@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }
 
-float get_scale(const Tensor& scale, size_t channel_ix) {
-  ET_CHECK_MSG(
-      (scale.scalar_type() == ScalarType::Double) ||
-          (scale.scalar_type() == ScalarType::Float),
-      "scale.scalar_type() %" PRId8 " is not double or float type",
-      static_cast<int8_t>(scale.scalar_type()));
-  if (scale.scalar_type() == ScalarType::Double) {
-    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
-  } else {
-    return scale.const_data_ptr<float>()[channel_ix];
-  }
-}
-
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
 
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {
diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp
index 8d23e74e41..676aa32690 100644
--- a/kernels/quantized/test/op_dequantize_test.cpp
+++ b/kernels/quantized/test/op_dequantize_test.cpp
@@ -123,13 +123,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
   EXPECT_TENSOR_EQ(out, expected);
 }
 
-TEST(OpDequantizeOutTest, DequantizePerChannel) {
-  et_pal_init();
-  TensorFactory<ScalarType::Byte> tf_byte;
+template <ScalarType DTYPE>
+void test_per_channel_dtype() {
+  TensorFactory<DTYPE> tf;
   TensorFactory<ScalarType::Double> tf_double;
   TensorFactory<ScalarType::Long> tf_long;
 
-  Tensor input = tf_byte.full({3, 2}, 100);
+  Tensor input = tf.full({3, 2}, 100);
   Tensor scale = tf_double.make({2}, {0.5, 1});
   Tensor zero_point = tf_long.make({2}, {30, 60});
   int64_t quant_min = 0;
@@ -147,7 +147,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/1,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);
 
@@ -168,7 +168,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);
 
@@ -176,7 +176,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
 
   // Test with a different axis
   out = tfo.zeros({3});
-  input = tf_byte.make({3}, {100, 100, 100});
+  input = tf.make({3}, {100, 100, 100});
   scale = tf_double.make({3}, {0.5, 0.75, 1});
   zero_point = tf_long.make({3}, {30, 50, 60});
   // (100 - 30) * 0.5
@@ -190,8 +190,42 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
+      optional<ScalarType>(),
+      out);
+  EXPECT_TENSOR_EQ(out, expected);
+
+  // Test with a different axis
+  input = tf.full({3, 19}, 100);
+  out = tfo.zeros({3, 19});
+  scale = tf_double.make({3}, {0.5, 0.75, 1});
+  zero_point = tf_long.make({3}, {30, 50, 60});
+  // (100 - 30) * 0.5
+  // (100 - 50) * 0.75
+  // (100 - 60) * 1
+  expected = tfo.make(
+      {3, 19},
+      {35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,
+       35,   35,   35,   35,   35,   35,   35,   37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 40,   40,   40,   40,   40,   40,   40,   40,   40,   40,
+       40,   40,   40,   40,   40,   40,   40,   40,   40});
+  dequantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      /*axis=*/0,
+      quant_min,
+      quant_max,
+      DTYPE,
       optional<ScalarType>(),
       out);
+  EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpDequantizeOutTest, DequantizePerChannel) {
+  et_pal_init();
+  test_per_channel_dtype<ScalarType::Byte>();
+  test_per_channel_dtype<ScalarType::Char>();
+}
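
Note (not part of the patch): the NEON fast path, the scalar tail in dequantize_optimized, and the per-channel driver above all compute the same per-element formula, out[i] = (in[i] - zero_point[c]) * scale[c] for every element i belonging to channel c along `axis`. The sketch below is a minimal, standalone scalar reference under assumed conditions (contiguous [outer, channels, inner] layout, int8 input, float output); the helper name `dequantize_per_channel_reference` is made up for illustration and is not an ExecuTorch API. It can be handy for spot-checking the vectorized path against ad-hoc inputs.

#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference for per-channel dequantization over a contiguous buffer
// laid out as [outer_size, channels, inner_size], channel axis in the middle.
// Mirrors the per-element formula used by the optimized kernel:
//   out[i] = (in[i] - zero_point[c]) * scale[c]
std::vector<float> dequantize_per_channel_reference(
    const std::vector<int8_t>& in,
    const std::vector<float>& scale, // one scale per channel
    const std::vector<int64_t>& zero_point, // one zero_point per channel
    size_t outer_size,
    size_t channels,
    size_t inner_size) {
  std::vector<float> out(in.size());
  for (size_t o = 0; o < outer_size; ++o) {
    for (size_t c = 0; c < channels; ++c) {
      const size_t base = (o * channels + c) * inner_size;
      for (size_t i = 0; i < inner_size; ++i) {
        out[base + i] =
            static_cast<float>(in[base + i] - zero_point[c]) * scale[c];
      }
    }
  }
  return out;
}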