From a17682de2038418028d2d1f6674c6d1235f902d3 Mon Sep 17 00:00:00 2001
From: cad-audio <86048415+cad-audio@users.noreply.github.com>
Date: Thu, 18 Apr 2024 08:04:35 -0700
Subject: [PATCH] Quantize, dequantize optimizations for HiFi targets (#2544)

BUG=none
---
 .../lite/micro/kernels/xtensa/dequantize.cc   | 117 ++++++++++++++++++
 .../lite/micro/kernels/xtensa/quantize.cc     |  85 ++++++++-----
 2 files changed, 173 insertions(+), 29 deletions(-)
 create mode 100644 tensorflow/lite/micro/kernels/xtensa/dequantize.cc

diff --git a/tensorflow/lite/micro/kernels/xtensa/dequantize.cc b/tensorflow/lite/micro/kernels/xtensa/dequantize.cc
new file mode 100644
index 00000000000..6d731bfdd75
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/dequantize.cc
@@ -0,0 +1,117 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/quantize.h"
+#include "tensorflow/lite/kernels/internal/reference/requantize.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/dequantize.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_log.h"
+
+namespace tflite {
+
+void* DequantizeInit(TfLiteContext* context, const char* buffer,
+                     size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(DequantizeOpData));
+}
+
+TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  DequantizeOpData* data = static_cast<DequantizeOpData*>(node->user_data);
+
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  // Output type ensured to be kTfLiteFloat32 at the Prepare stage
+  TFLITE_DCHECK(output->type == kTfLiteFloat32);
+
+  switch (input->type) {
+    case kTfLiteInt8: {
+#if HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
+      int err;
+      const int8_t* input_data_ptr;
+      float* output_data_ptr;
+      const int flat_size =
+          MatchingFlatSize(tflite::micro::GetTensorShape(input),
+                           tflite::micro::GetTensorShape(output));
+      input_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
+      output_data_ptr = tflite::micro::GetTensorData<float>(output);
+
+      err = xa_nn_elm_dequantize_asym8s_f32(
+          output_data_ptr, input_data_ptr, data->quantization_params.zero_point,
+          data->quantization_params.scale, flat_size);
+      TF_LITE_ENSURE(context, (err == 0));
+#else  // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
+      reference_ops::Dequantize(data->quantization_params,
+                                tflite::micro::GetTensorShape(input),
+                                tflite::micro::GetTensorData<int8_t>(input),
+                                tflite::micro::GetTensorShape(output),
+                                tflite::micro::GetTensorData<float>(output));
+#endif  // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
+      break;
+    }
+    case kTfLiteInt16: {
+#if HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
+      int err;
+      const int16_t* input_data_ptr;
+      float* output_data_ptr;
+      const RuntimeShape& input_shape = tflite::micro::GetTensorShape(input);
+      const RuntimeShape& output_shape = tflite::micro::GetTensorShape(output);
+      const int flat_size = MatchingFlatSize(input_shape, output_shape);
+      input_data_ptr = tflite::micro::GetTensorData<int16_t>(input);
+      output_data_ptr = tflite::micro::GetTensorData<float>(output);
+      err = xa_nn_elm_dequantize_asym16s_f32(
+          output_data_ptr, input_data_ptr, data->quantization_params.zero_point,
+          data->quantization_params.scale, flat_size);
+      TF_LITE_ENSURE(context, (err == 0));
+#else  // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
+      reference_ops::Dequantize(data->quantization_params,
+                                tflite::micro::GetTensorShape(input),
+                                tflite::micro::GetTensorData<int16_t>(input),
+                                tflite::micro::GetTensorShape(output),
+                                tflite::micro::GetTensorData<float>(output));
+#endif  // HIFI_VFPU && (defined(HIFI5) || defined(HIFI4))
+      break;
+    }
+    case kTfLiteUInt8:
+      reference_ops::Dequantize(data->quantization_params,
+                                tflite::micro::GetTensorShape(input),
+                                tflite::micro::GetTensorData<uint8_t>(input),
+                                tflite::micro::GetTensorShape(output),
+                                tflite::micro::GetTensorData<float>(output));
+      break;
+    default:
+      MicroPrintf("Input %s, output %s not supported.",
+                  TfLiteTypeGetName(input->type),
+                  TfLiteTypeGetName(output->type));
+      return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+TFLMRegistration Register_DEQUANTIZE() {
+  return tflite::micro::RegisterOp(DequantizeInit, DequantizePrepare,
+                                   DequantizeEval);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa/quantize.cc b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
index 15f5243e063..06d4fbbff19 100644
--- a/tensorflow/lite/micro/kernels/xtensa/quantize.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/quantize.cc
@@ -75,12 +75,19 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
 
         case kTfLiteInt8: {
           int size = ElementCount(*input->dims);
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int8_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              op_data->quantization_params.zero_point,
-              tflite::micro::GetTensorData<int8_t>(output));
+          int32_t zero_point = op_data->quantization_params.zero_point;
+          const int8_t* input_data_ptr;
+          int8_t* output_data_ptr;
+          input_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
+          output_data_ptr = tflite::micro::GetTensorData<int8_t>(output);
+
+          TF_LITE_ENSURE_EQ(
+              context,
+              xa_nn_elm_requantize_asym8s_asym8s(
+                  output_data_ptr, input_data_ptr, op_data->input_zero_point,
+                  zero_point, op_data->requantize_output_shift,
+                  op_data->requantize_output_multiplier, size),
+              0);
           break;
         }
 
@@ -98,7 +105,6 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
         case kTfLiteInt32: {
           int size = ElementCount(*input->dims);
           int32_t zero_point = op_data->quantization_params.zero_point;
-#if defined(HIFI5)
           const int8_t* input_data_ptr;
           int32_t* output_data_ptr;
           input_data_ptr = tflite::micro::GetTensorData<int8_t>(input);
@@ -111,13 +117,6 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
                   zero_point, op_data->requantize_output_shift,
                   op_data->requantize_output_multiplier, size),
               0);
-#else
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int8_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              zero_point, tflite::micro::GetTensorData<int32_t>(output));
-#endif  // defined(HIFI5)
           break;
         }
 
@@ -149,18 +148,20 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
 
         case kTfLiteInt16: {
           int size = ElementCount(*input->dims);
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int16_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              op_data->quantization_params.zero_point,
-              tflite::micro::GetTensorData<int16_t>(output));
+          TF_LITE_ENSURE_EQ(context,
+                            xa_nn_elm_requantize_asym16s_asym16s(
+                                tflite::micro::GetTensorData<int16_t>(output),
+                                tflite::micro::GetTensorData<int16_t>(input),
+                                op_data->input_zero_point,
+                                op_data->quantization_params.zero_point,
+                                op_data->requantize_output_shift,
+                                op_data->requantize_output_multiplier, size),
+                            0);
           break;
         }
 
         case kTfLiteInt32: {
           int size = ElementCount(*input->dims);
-#if defined(HIFI5)
           TF_LITE_ENSURE_EQ(context,
                             xa_nn_elm_requantize_asym16s_asym32s(
                                 tflite::micro::GetTensorData<int32_t>(output),
@@ -170,14 +171,6 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
                                 op_data->requantize_output_shift,
                                 op_data->requantize_output_multiplier, size),
                             0);
-#else
-          int32_t zero_point = op_data->quantization_params.zero_point;
-          reference_ops::Requantize(
-              tflite::micro::GetTensorData<int16_t>(input), size,
-              op_data->requantize_output_multiplier,
-              op_data->requantize_output_shift, op_data->input_zero_point,
-              zero_point, tflite::micro::GetTensorData<int32_t>(output));
-#endif  // defined(HIFI5)
           break;
         }
 
@@ -228,22 +221,56 @@ TfLiteStatus EvalXtensa(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteFloat32: {
       switch (output->type) {
         case kTfLiteInt8: {
+#if HIFI_VFPU
+          int size = ElementCount(*input->dims);
+          int32_t zero_point = op_data->quantization_params.zero_point;
+          const float* input_data_ptr;
+          int8_t* output_data_ptr;
+          input_data_ptr = tflite::micro::GetTensorData<float>(input);
+          output_data_ptr = tflite::micro::GetTensorData<int8_t>(output);
+
+          TF_LITE_ENSURE_EQ(
+              context,
+              xa_nn_elm_quantize_f32_asym8s(
+                  output_data_ptr, input_data_ptr,
+                  static_cast<float>(op_data->quantization_params.scale),
+                  zero_point, size),
+              0);
+#else  // #if HIFI_VFPU
           reference_ops::AffineQuantize(
               op_data->quantization_params,
              tflite::micro::GetTensorShape(input),
              tflite::micro::GetTensorData<float>(input),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<int8_t>(output));
+#endif  // #if HIFI_VFPU
           break;
         }
 
         case kTfLiteInt16: {
+#if HIFI_VFPU
+          int size = ElementCount(*input->dims);
+          int32_t zero_point = op_data->quantization_params.zero_point;
+          const float* input_data_ptr;
+          int16_t* output_data_ptr;
+          input_data_ptr = tflite::micro::GetTensorData<float>(input);
+          output_data_ptr = tflite::micro::GetTensorData<int16_t>(output);
+
+          TF_LITE_ENSURE_EQ(
+              context,
+              xa_nn_elm_quantize_f32_asym16s(
+                  output_data_ptr, input_data_ptr,
+                  static_cast<float>(op_data->quantization_params.scale),
+                  zero_point, size),
+              0);
+#else  // #if HIFI_VFPU
           reference_ops::AffineQuantize(
              op_data->quantization_params,
              tflite::micro::GetTensorShape(input),
             tflite::micro::GetTensorData<float>(input),
             tflite::micro::GetTensorShape(output),
             tflite::micro::GetTensorData<int16_t>(output));
+#endif  // #if HIFI_VFPU
           break;
         }
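
Note for reviewers (not part of the patch): the xa_nn_elm_* calls above offload the same element-wise affine math that the reference kernels (reference_ops::Dequantize / AffineQuantize / Requantize) compute, using HiFi SIMD and, for the float paths, the vector FPU guarded by HIFI_VFPU. As a rough scalar sketch of that math, under the usual TFLite quantization scheme and with helper names invented purely for illustration:

// Illustrative sketch only; these helpers are hypothetical, not NNLib APIs.
// Dequantize: real = scale * (q - zero_point). Quantize is the inverse with
// rounding and saturation to the target type's range. (The requantize paths
// use a fixed-point output multiplier and shift instead of a float scale.)
#include <algorithm>
#include <cmath>
#include <cstdint>

void DequantizeAsym8sToF32(const int8_t* in, float* out, int num_elements,
                           int32_t zero_point, float scale) {
  for (int i = 0; i < num_elements; ++i) {
    out[i] = scale * static_cast<float>(in[i] - zero_point);
  }
}

void QuantizeF32ToAsym8s(const float* in, int8_t* out, int num_elements,
                         float scale, int32_t zero_point) {
  for (int i = 0; i < num_elements; ++i) {
    // Round to nearest, shift by the zero point, then clamp to int8 range.
    int32_t q = static_cast<int32_t>(std::lround(in[i] / scale)) + zero_point;
    out[i] = static_cast<int8_t>(
        std::min<int32_t>(127, std::max<int32_t>(-128, q)));
  }
}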