Skip to content

Commit

Permalink
Fix dequantize per channel to handle double scale type (#5524)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5524

ghstack-source-id: 244685036
exported-using-ghexport

Reviewed By: swolchok

Differential Revision: D62301839

fbshipit-source-id: ac969b80fda97adacef0ad6afab3bc0cf34050b0
  • Loading branch information
kimishpatel authored and facebook-github-bot committed Sep 26, 2024
1 parent 985f92d commit 9d224a5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
25 changes: 16 additions & 9 deletions kernels/quantized/cpu/op_dequantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@ Tensor& dequantize_per_tensor_tensor_args_out(
return out;
}

float get_scale(const Tensor& scale, size_t channel_ix) {
ET_CHECK_MSG(
(scale.scalar_type() == ScalarType::Double) ||
(scale.scalar_type() == ScalarType::Float),
"scale.scalar_type() %" PRId8 " is not double or float type",
static_cast<int8_t>(scale.scalar_type()));
if (scale.scalar_type() == ScalarType::Double) {
return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
} else {
return scale.const_data_ptr<float>()[channel_ix];
}
}

Tensor& dequantize_per_channel_out(
const Tensor& input,
const Tensor& scale,
Expand Down Expand Up @@ -195,11 +208,6 @@ Tensor& dequantize_per_channel_out(
err == torch::executor::Error::Ok,
"Failed to resize out Tensor in dequantize_per_channel_out");

ET_CHECK_MSG(
scale.scalar_type() == ScalarType::Float,
"scale.scalar_type() %" PRId8 " is not float type",
static_cast<int8_t>(scale.scalar_type()));

ET_CHECK_MSG(
scale.numel() == input.size(axis),
"scale.numel() %zd != input.size(axis) %zd",
Expand Down Expand Up @@ -232,7 +240,6 @@ Tensor& dequantize_per_channel_out(
dims[i] = i + 1;
}
}
const float* scale_data = scale.const_data_ptr<float>();
const int64_t* zero_point_data;
if (opt_zero_points.has_value()) {
zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
Expand Down Expand Up @@ -260,11 +267,11 @@ Tensor& dequantize_per_channel_out(
axis == 0, "Axis must be 0 for a single dimensional tensors"); \
const optional<int64_t> dim; \
apply_over_dim( \
[input_data_ptr, out_data_ptr, scale_data, zero_point_data]( \
[input_data_ptr, out_data_ptr, zero_point_data, &scale]( \
size_t numel, size_t stride, size_t base_ix) { \
for (size_t i = 0; i < numel; i++) { \
size_t current_ix = base_ix * stride + i; \
float _scale = scale_data[current_ix]; \
float _scale = get_scale(scale, current_ix); \
int64_t zero_point = 0; \
if (zero_point_data != nullptr) { \
zero_point = zero_point_data[current_ix]; \
Expand All @@ -280,7 +287,7 @@ Tensor& dequantize_per_channel_out(
break; \
} \
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
float _scale = scale_data[channel_ix]; \
float _scale = get_scale(scale, channel_ix); \
int64_t _zero_point = 0; \
if (zero_point_data != nullptr) { \
_zero_point = zero_point_data[channel_ix]; \
Expand Down
13 changes: 9 additions & 4 deletions kernels/quantized/test/op_dequantize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -57,10 +58,12 @@ void test_dtype() {
}

TEST(OpDequantizeOutTest, AllDtypesSupported) {
et_pal_init();
test_dtype<ScalarType::Byte>();
}

TEST(OpDequantizeOutTest, NonWholeNumbers) {
et_pal_init();
TensorFactory<ScalarType::Byte> tf;

Tensor input = tf.full({3, 5}, 100);
Expand All @@ -87,6 +90,7 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) {
}

TEST(OpDequantizeOutTest, TensorArgOverload) {
et_pal_init();
TensorFactory<ScalarType::Byte> tf_byte;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;
Expand Down Expand Up @@ -115,12 +119,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
}

TEST(OpDequantizeOutTest, DequantizePerChannel) {
et_pal_init();
TensorFactory<ScalarType::Byte> tf_byte;
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_byte.full({3, 2}, 100);
Tensor scale = tf_float.make({2}, {0.5, 1});
Tensor scale = tf_double.make({2}, {0.5, 1});
Tensor zero_point = tf_long.make({2}, {30, 60});
int64_t quant_min = 0;
int64_t quant_max = 255;
Expand All @@ -145,7 +150,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {

// Test with a different axis
out = tfo.zeros({3, 2});
scale = tf_float.make({3}, {0.5, 0.75, 1});
scale = tf_double.make({3}, {0.5, 0.75, 1});
zero_point = tf_long.make({3}, {30, 50, 60});
// (100 - 30) * 0.5
// (100 - 50) * 0.75
Expand All @@ -167,7 +172,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
// Test with a different axis
out = tfo.zeros({3});
input = tf_byte.make({3}, {100, 100, 100});
scale = tf_float.make({3}, {0.5, 0.75, 1});
scale = tf_double.make({3}, {0.5, 0.75, 1});
zero_point = tf_long.make({3}, {30, 50, 60});
// (100 - 30) * 0.5
// (100 - 50) * 0.75
Expand Down

0 comments on commit 9d224a5

Please sign in to comment.