Quantize, dequantize optimizations for HiFi targets (#2544)
BUG=none
cad-audio authored Apr 18, 2024
1 parent 26fbc4b commit a17682d
Showing 2 changed files with 173 additions and 29 deletions.
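For readers skimming the diff: the new fast paths call Cadence NNLib element-wise kernels (xa_nn_elm_dequantize_*, xa_nn_elm_quantize_*, xa_nn_elm_requantize_*) in place of the portable reference loops, without changing the arithmetic. As a minimal sketch of that arithmetic, written here from the standard TFLite affine-quantization scheme rather than copied from this commit (the helper names below are illustrative, not part of the sources; note the diff casts the double-typed scale to float before calling the quantize kernels):

// Affine (de)quantization math the kernels implement; the (-128, 127)
// range is the kTfLiteInt8 case, kTfLiteInt16 is analogous.
#include <algorithm>
#include <cmath>
#include <cstdint>

// Dequantize, as in reference_ops::Dequantize:
//   real = scale * (q - zero_point).
inline float DequantizeInt8(int8_t q, float scale, int32_t zero_point) {
  return scale * static_cast<float>(static_cast<int32_t>(q) - zero_point);
}

// Quantize, as in reference_ops::AffineQuantize: round to nearest, shift
// by the zero point, then clamp to the output type's range.
inline int8_t QuantizeInt8(float x, float scale, int32_t zero_point) {
  int32_t q = static_cast<int32_t>(std::round(x / scale)) + zero_point;
  return static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, q)));
}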
117 changes: 117 additions & 0 deletions tensorflow/lite/micro/kernels/xtensa/dequantize.cc
@@ -0,0 +1,117 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/dequantize.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/dequantize.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {

void* DequantizeInit(TfLiteContext* context, const char* buffer,
                     size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(DequantizeOpData));
}

TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  DequantizeOpData* data = static_cast<DequantizeOpData*>(node->user_data);

  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);

  // Output type ensured to be kTfLiteFloat32 at the Prepare stage
  TFLITE_DCHECK(output->type == kTfLiteFloat32);

  switch (input->type) {
    case kTfLiteInt8: {
#if HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
      int err;
      const int8_t* input_data_ptr;
      float* output_data_ptr;
      const int flat_size =
          MatchingFlatSize(tflite::micro::GetTensorShape(input),
                           tflite::micro::GetTensorShape(output));
      input_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
      output_data_ptr = tflite::micro::GetTensorData<float>(output);

      err = xa_nn_elm_dequantize_asym8s_f32(
          output_data_ptr, input_data_ptr, data->quantization_params.zero_point,
          data->quantization_params.scale, flat_size);
      TF_LITE_ENSURE(context, (err == 0));
#else   // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
      reference_ops::Dequantize(data->quantization_params,
                                tflite::micro::GetTensorShape(input),
                                tflite::micro::GetTensorData<int8_t>(input),
                                tflite::micro::GetTensorShape(output),
                                tflite::micro::GetTensorData<float>(output));
#endif  // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
      break;
    }
    case kTfLiteInt16: {
#if HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
      int err;
      const int16_t* input_data_ptr;
      float* output_data_ptr;
      const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
      const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
      const int flat_size = MatchingFlatSize(input_shape, output_shape);
      input_data_ptr = tflite::micro::GetTensorData<int16_t>(input);
      output_data_ptr = tflite::micro::GetTensorData<float>(output);
      err = xa_nn_elm_dequantize_asym16s_f32(
          output_data_ptr, input_data_ptr, data->quantization_params.zero_point,
          data->quantization_params.scale, flat_size);
      TF_LITE_ENSURE(context, (err == 0));
#else   // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
      reference_ops::Dequantize(data->quantization_params,
                                tflite::micro::GetTensorShape(input),
                                tflite::micro::GetTensorData<int16_t>(input),
                                tflite::micro::GetTensorShape(output),
                                tflite::micro::GetTensorData<float>(output));
#endif  // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
      break;
    }
    case kTfLiteUInt8:
      reference_ops::Dequantize(data->quantization_params,
                                tflite::micro::GetTensorShape(input),
                                tflite::micro::GetTensorData<uint8_t>(input),
                                tflite::micro::GetTensorShape(output),
                                tflite::micro::GetTensorData<float>(output));
      break;
    default:
      MicroPrintf("Input %s, output %s not supported.",
                  TfLiteTypeGetName(input->type),
                  TfLiteTypeGetName(output->type));
      return kTfLiteError;
  }

  return kTfLiteOk;
}

TFLMRegistration Register_DEQUANTIZE() {
  return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
                                   DequantizeEval);
}

} // namespace tflite
85 changes: 56 additions & 29 deletions tensorflow/lite/micro/kernels/xtensa/quantize.cc
@@ -75,12 +75,19 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
 
        case kTfLiteInt8: {
          int size = ElementCount(*input->dims);
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int8_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              op_data->quantization_params.zero_point,
-              tflite::micro::GetTensorData<int8_t>(output));
+          int32_t zero_point = op_data->quantization_params.zero_point;
+          const int8_t* input_data_ptr;
+          int8_t* output_data_ptr;
+          input_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
+          output_data_ptr = tflite::micro::GetTensorData<int8_t>(output);
+
+          TF_LITE_ENSURE_EQ(
+              context,
+              xa_nn_elm_requantize_asym8s_asym8s(
+                  output_data_ptr, input_data_ptr, op_data->input_zero_point,
+                  zero_point, op_data->requantize_output_shift,
+                  op_data->requantize_output_multiplier, size),
+              0);
          break;
        }

@@ -98,7 +105,6 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
        case kTfLiteInt32: {
          int size = ElementCount(*input->dims);
          int32_t zero_point = op_data->quantization_params.zero_point;
-#if defined(HIFI5)
          const int8_t* input_data_ptr;
          int32_t* output_data_ptr;
          input_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
@@ -111,13 +117,6 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
                  zero_point, op_data->requantize_output_shift,
                  op_data->requantize_output_multiplier, size),
              0);
-#else
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int8_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              zero_point, tflite::micro::GetTensorData<int32_t>(output));
-#endif  // defined(HIFI5)
          break;
        }

@@ -149,18 +148,20 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
 
        case kTfLiteInt16: {
          int size = ElementCount(*input->dims);
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int16_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              op_data->quantization_params.zero_point,
-              tflite::micro::GetTensorData<int16_t>(output));
+          TF_LITE_ENSURE_EQ(context,
+                            xa_nn_elm_requantize_asym16s_asym16s(
+                                tflite::micro::GetTensorData<int16_t>(output),
+                                tflite::micro::GetTensorData<int16_t>(input),
+                                op_data->input_zero_point,
+                                op_data->quantization_params.zero_point,
+                                op_data->requantize_output_shift,
+                                op_data->requantize_output_multiplier, size),
+                            0);
          break;
        }
 
        case kTfLiteInt32: {
          int size = ElementCount(*input->dims);
-#if defined(HIFI5)
          TF_LITE_ENSURE_EQ(context,
                            xa_nn_elm_requantize_asym16s_asym32s(
                                tflite::micro::GetTensorData<int32_t>(output),
@@ -170,14 +171,6 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
                                op_data->requantize_output_shift,
                                op_data->requantize_output_multiplier, size),
                            0);
-#else
-          int32_t zero_point = op_data->quantization_params.zero_point;
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int16_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              zero_point, tflite::micro::GetTensorData<int32_t>(output));
-#endif  // defined(HIFI5)
          break;
        }

@@ -228,22 +221,56 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteFloat32: {
       switch (output->type) {
         case kTfLiteInt8: {
+#if HIFI_VFPU
+          int size = ElementCount(*input->dims);
+          int32_t zero_point = op_data->quantization_params.zero_point;
+          const float* input_data_ptr;
+          int8_t* output_data_ptr;
+          input_data_ptr = tflite::micro::GetTensorData<float>(input);
+          output_data_ptr = tflite::micro::GetTensorData<int8_t>(output);
+
+          TF_LITE_ENSURE_EQ(
+              context,
+              xa_nn_elm_quantize_f32_asym8s(
+                  output_data_ptr, input_data_ptr,
+                  static_cast<float>(op_data->quantization_params.scale),
+                  zero_point, size),
+              0);
+#else   // #if HIFI_VFPU
          reference_ops::AffineQuantize(
              op_data->quantization_params,
              tflite::micro::GetTensorShape(input),
              tflite::micro::GetTensorData<float>(input),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<int8_t>(output));
+#endif  // #if HIFI_VFPU
          break;
        }
 
        case kTfLiteInt16: {
+#if HIFI_VFPU
+          int size = ElementCount(*input->dims);
+          int32_t zero_point = op_data->quantization_params.zero_point;
+          const float* input_data_ptr;
+          int16_t* output_data_ptr;
+          input_data_ptr = tflite::micro::GetTensorData<float>(input);
+          output_data_ptr = tflite::micro::GetTensorData<int16_t>(output);
+
+          TF_LITE_ENSURE_EQ(
+              context,
+              xa_nn_elm_quantize_f32_asym16s(
+                  output_data_ptr, input_data_ptr,
+                  static_cast<float>(op_data->quantization_params.scale),
+                  zero_point, size),
+              0);
+#else   // #if HIFI_VFPU
          reference_ops::AffineQuantize(
              op_data->quantization_params,
              tflite::micro::GetTensorShape(input),
              tflite::micro::GetTensorData<float>(input),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<int16_t>(output));
+#endif  // #if HIFI_VFPU
          break;
        }
 
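The xa_nn_elm_requantize_* calls above take the same parameters as the reference_ops::Requantize loops they replace: the two zero points plus a quantized multiplier and shift that together encode the ratio of input scale to output scale. As a rough sketch of the computation (a simplified, float-based stand-in; the real kernels use a fixed-point multiply-by-quantized-multiplier, and RequantizeInt8 is an illustrative name, not part of the sources):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Map int8 values from one (scale, zero_point) pair to another.
// effective_scale = input_scale / output_scale, which the real code encodes
// as the (requantize_output_multiplier, requantize_output_shift) pair.
inline int8_t RequantizeInt8(int8_t in, int32_t input_zero_point,
                             int32_t output_zero_point,
                             double effective_scale) {
  int32_t v = static_cast<int32_t>(std::round(
      (static_cast<int32_t>(in) - input_zero_point) * effective_scale));
  v += output_zero_point;
  return static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, v)));
}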
