[Executorch][quant] Optimize per channel dequantize
Pull Request resolved: #5670

When using a quantized KV cache, the dequantization routine takes a significant amount of time.
This diff vectorizes per-channel dequantization for the common case.
ghstack-source-id: 255730818
exported-using-ghexport

Differential Revision: [D63338858](https://our.internmc.facebook.com/intern/diff/D63338858/)

Co-authored-by: Kimish Patel <[email protected]>
pytorchbot and kimishpatel authored Dec 2, 2024
1 parent 9d084c4 commit ddec0c7
Showing 2 changed files with 238 additions and 21 deletions.
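
For context, per-channel dequantization maps each quantized value back to float as out = (in - zero_point) * scale, with a separate scale/zero_point per channel along the quantization axis. Below is a minimal scalar sketch of the operation this change vectorizes (illustrative only; the function name and the flat [channels, inner] layout are assumptions, not code from this diff):

#include <cstddef>
#include <cstdint>

// Scalar per-channel dequantization over a contiguous [channels, inner]
// int8 buffer; channel c uses scales[c] and zero_points[c].
void dequantize_per_channel_reference(
    const int8_t* in,
    const float* scales,
    const int64_t* zero_points,
    float* out,
    size_t channels,
    size_t inner) {
  for (size_t c = 0; c < channels; ++c) {
    for (size_t i = 0; i < inner; ++i) {
      const size_t idx = c * inner + i;
      out[idx] = (in[idx] - zero_points[c]) * scales[c];
    }
  }
}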
209 changes: 196 additions & 13 deletions kernels/quantized/cpu/op_dequantize.cpp
@@ -11,6 +11,9 @@
#include <algorithm>
#include <cinttypes>
#include <cmath>
#if defined(__aarch64__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif

/**
* For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
using StridesType = exec_aten::StridesType;
using SizesType = exec_aten::SizesType;

namespace {

@@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args(
quant_max);
}

/**
* Iterates over a tensor `in` along a given dimension `dim`, invoking `fn`
* once for every (outer block, dim index) pair. `fn` should have the
* signature:
* void fn(const size_t inner_size, const size_t outer_idx, const size_t dim_idx)
* where `inner_size` is the number of contiguous elements trailing `dim`,
* `outer_idx` indexes the flattened dimensions preceding `dim`, and `dim_idx`
* is the index along `dim` itself.
*/
template <typename Fn>
void apply_over_unpacked_dim(
const Fn& fn,
const exec_aten::Tensor& in,
const int64_t& dim) {
if (in.numel() == 0) {
return;
}

ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
ET_CHECK_VALID_DIM(dim, in.dim());

const size_t d = ET_NORMALIZE_IX(dim, in.dim());
const size_t dim_size = in.size(d);
const size_t outer_size = getLeadingDims(in, d);
const size_t inner_size = getTrailingDims(in, d);
// Loop through all outer dimensions
for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
// Loop through dim
for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
++unpacked_dim_idx) {
fn(inner_size, outer_idx, unpacked_dim_idx);
}
}
}
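// Illustrative usage (not part of this change): visit every slice along
// dimension `dim` of a tensor `t`. The callback receives the number of
// contiguous trailing elements (`inner_size`), the index of the enclosing
// outer block (`outer_idx`), and the index along `dim` (`dim_idx`):
//
//   apply_over_unpacked_dim(
//       [](SizesType inner_size, SizesType outer_idx, SizesType dim_idx) {
//         // process the `inner_size` contiguous elements of this slice
//       },
//       t,
//       /*dim=*/1);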

void dequantize_optimized(
const int8_t* in,
const double scale,
const int64_t zero_point,
float* out,
int64_t quant_min,
int64_t quant_max,
size_t numel) {
ET_CHECK_MSG(
zero_point >= quant_min,
"zero_point %" PRId64 " must be >= quant_min %" PRId64,
zero_point,
quant_min);
ET_CHECK_MSG(
zero_point <= quant_max,
"zero_point %" PRId64 " must be <= quant_max %" PRId64,
zero_point,
quant_max);
size_t i = 0;
#if defined(__aarch64__) || defined(__ARM_NEON)
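// NEON fast path: process 16 int8 elements per iteration. vsubl_s8
// subtracts the zero point while widening to int16; each half is then
// widened to int32, converted to float, and multiplied by the scale.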
int8x8_t zero_point_vec = vdup_n_s8(zero_point);
float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
constexpr int32_t kVecSize = 16;
const size_t num_vecs = numel / kVecSize;
const int8_t* in_copy = in;
float* out_copy = out;
for (; i < num_vecs; i++) {
int8x16_t in_vec = vld1q_s8(in_copy);
int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);

int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
vst1q_f32(out_copy + 0, out_vec_0_3);
vst1q_f32(out_copy + 4, out_vec_4_7);
vst1q_f32(out_copy + 8, out_vec_8_11);
vst1q_f32(out_copy + 12, out_vec_12_15);
in_copy += kVecSize;
out_copy += kVecSize;
}
i = i * kVecSize;
#endif
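// Scalar tail (and fallback when NEON is unavailable): dequantize any
// remaining elements one at a time.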
for (; i < numel; i++) {
out[i] = (in[i] - zero_point) * scale;
}
}

float get_scale(const Tensor& scale, size_t channel_ix) {
ET_CHECK_MSG(
(scale.scalar_type() == ScalarType::Double) ||
(scale.scalar_type() == ScalarType::Float),
"scale.scalar_type() %" PRId8 " is not double or float type",
static_cast<int8_t>(scale.scalar_type()));
if (scale.scalar_type() == ScalarType::Double) {
return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
} else {
return scale.const_data_ptr<float>()[channel_ix];
}
}

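// The optimized per-channel path below requires a contiguous dim order,
// int8 (Char) input, and a float (or unspecified) output dtype.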
bool can_use_optimized_dequantize_per_channel(
const Tensor& in,
const ScalarType in_dtype,
exec_aten::optional<ScalarType>& out_dtype) {
bool is_contiguous = false;
#ifdef USE_ATEN_LIB
is_contiguous = in.is_contiguous();
#else
is_contiguous = executorch::runtime::is_contiguous_dim_order(
in.dim_order().data(), in.dim());
#endif
if (!is_contiguous || (in_dtype != ScalarType::Char) ||
(out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
return false;
}
return true;
}

void dequantize_per_channel_optimized(
const Tensor& in,
const Tensor& scales,
const optional<Tensor>& opt_zero_points,
Tensor& out,
int64_t axis,
int64_t quant_min,
int64_t quant_max,
ScalarType in_dtype,
exec_aten::optional<ScalarType>& out_dtype) {
check_dequantize_per_tensor_args(
in, quant_min, quant_max, in_dtype, out_dtype, out);
ET_CHECK_MSG(
in_dtype == ScalarType::Char,
"in.scalar_type() %" PRId8 " is not supported",
static_cast<int8_t>(in.scalar_type()));
if (out_dtype.has_value()) {
ET_CHECK_MSG(
out_dtype.value() == ScalarType::Float,
"Only float output is supported");
}
const int8_t* in_data = in.const_data_ptr<int8_t>();
float* out_data = out.mutable_data_ptr<float>();
const int64_t* zero_points_data = nullptr;
if (opt_zero_points.has_value()) {
zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
}
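// With a contiguous layout, elements along `axis` are `axis_stride` apart
// and consecutive outer blocks start `in.size(axis) * axis_stride` elements
// apart; the same offsets are used for the float output.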
const StridesType axis_stride = in.strides()[axis];
const StridesType outer_stride = in.size(axis) * axis_stride;
apply_over_unpacked_dim(
[in_data,
out_data,
&scales,
zero_points_data,
axis_stride,
outer_stride,
quant_min,
quant_max](
SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
const int8_t* in_data_local =
in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
const double scale = get_scale(scales, unpacked_dim_idx);
const int64_t zero_point = zero_points_data != nullptr
? zero_points_data[unpacked_dim_idx]
: 0;
float* out_data_local = out_data + outer_idx * outer_stride +
unpacked_dim_idx * axis_stride;
dequantize_optimized(
in_data_local,
scale,
zero_point,
out_data_local,
quant_min,
quant_max,
numel);
},
in,
axis);
}

} // namespace

/**
@@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
return out;
}

float get_scale(const Tensor& scale, size_t channel_ix) {
ET_CHECK_MSG(
(scale.scalar_type() == ScalarType::Double) ||
(scale.scalar_type() == ScalarType::Float),
"scale.scalar_type() %" PRId8 " is not double or float type",
static_cast<int8_t>(scale.scalar_type()));
if (scale.scalar_type() == ScalarType::Double) {
return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
} else {
return scale.const_data_ptr<float>()[channel_ix];
}
}

Tensor& dequantize_per_channel_out(
const Tensor& input,
const Tensor& scale,
@@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out(
check_dequantize_per_tensor_args(
input, quant_min, quant_max, dtype, out_dtype, out);

if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
dequantize_per_channel_optimized(
input,
scale,
opt_zero_points,
out,
axis,
quant_min,
quant_max,
dtype,
out_dtype);
return out;
}

// a list contains all dimensions except axis
int64_t dims[kTensorDimensionLimit];
for (int64_t i = 0; i < input.dim() - 1; i++) {
50 changes: 42 additions & 8 deletions kernels/quantized/test/op_dequantize_test.cpp
@@ -123,13 +123,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpDequantizeOutTest, DequantizePerChannel) {
et_pal_init();
TensorFactory<ScalarType::Byte> tf_byte;
template <ScalarType DTYPE>
void test_per_channel_dtype() {
TensorFactory<DTYPE> tf;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_byte.full({3, 2}, 100);
Tensor input = tf.full({3, 2}, 100);
Tensor scale = tf_double.make({2}, {0.5, 1});
Tensor zero_point = tf_long.make({2}, {30, 60});
int64_t quant_min = 0;
@@ -147,7 +147,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
/*axis=*/1,
quant_min,
quant_max,
ScalarType::Byte,
DTYPE,
optional<ScalarType>(),
out);

@@ -168,15 +168,15 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
/*axis=*/0,
quant_min,
quant_max,
ScalarType::Byte,
DTYPE,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);

// Test with a different axis
out = tfo.zeros({3});
input = tf_byte.make({3}, {100, 100, 100});
input = tf.make({3}, {100, 100, 100});
scale = tf_double.make({3}, {0.5, 0.75, 1});
zero_point = tf_long.make({3}, {30, 50, 60});
// (100 - 30) * 0.5
@@ -190,8 +190,42 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
/*axis=*/0,
quant_min,
quant_max,
ScalarType::Byte,
DTYPE,
optional<ScalarType>(),
out);
EXPECT_TENSOR_EQ(out, expected);

// Test with a larger tensor that exercises the vectorized dequantization path
input = tf.full({3, 19}, 100);
out = tfo.zeros({3, 19});
scale = tf_double.make({3}, {0.5, 0.75, 1});
zero_point = tf_long.make({3}, {30, 50, 60});
// (100 - 30) * 0.5
// (100 - 50) * 0.75
// (100 - 60) * 1
expected = tfo.make(
{3, 19},
{35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35,
35, 35, 35, 35, 35, 35, 35, 37.5, 37.5, 37.5, 37.5, 37.5,
37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5,
37.5, 37.5, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40,
40, 40, 40, 40, 40, 40, 40, 40, 40});
dequantize_per_channel_out(
input,
scale,
zero_point,
/*axis=*/0,
quant_min,
quant_max,
DTYPE,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpDequantizeOutTest, DequantizePerChannel) {
et_pal_init();
test_per_channel_dtype<ScalarType::Byte>();
test_per_channel_dtype<ScalarType::Char>();
}
