[QNN EP] Enable Expand op (microsoft#18234)
### Description
Enable the Expand op.
There is no direct mapping from the ONNX Expand op to a QNN op, so ElementWiseMultiply is used to implement the data broadcast: create a second input filled with the value 1.0, shaped according to the Expand op's shape input, and multiply the first input by it.
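Conceptually the lowering amounts to the sketch below (a minimal illustration, not the EP implementation; `MakeOnesTensor` is a hypothetical helper):

```cpp
// Expand(x, shape) is lowered to ElementWiseMultiply(x, ones(shape)); the
// multiply's implicit broadcasting performs the actual expansion.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: build the constant all-ones tensor used as the 2nd input.
std::vector<float> MakeOnesTensor(const std::vector<uint32_t>& shape) {
  const size_t num_elements = std::accumulate(shape.begin(), shape.end(),
                                              size_t{1}, std::multiplies<size_t>());
  return std::vector<float>(num_elements, 1.0f);  // every element is 1.0f
}

// e.g. Expand(x[3,1], shape=[3,4]) -> ElementWiseMultiply(x, ones([3,4])) -> y[3,4]
```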
HectorSVC authored Nov 7, 2023
1 parent 3b63d85 commit ad34c67
Showing 13 changed files with 609 additions and 347 deletions.
@@ -30,6 +30,7 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
return {{"Gather", {}},
{"Reshape", {}},
{"Expand", {}},
{"Flatten", {}},
{"Transpose", {}},
{"MaxPool", {12}},
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -161,6 +161,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
{
CreatePadOpBuilder("Pad", *this);
}

{
CreateExpandOpBuilder("Expand", *this);
}
}

const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -92,5 +92,7 @@ void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations

void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

} // namespace qnn
} // namespace onnxruntime
@@ -164,7 +164,9 @@ class BaseOpBuilder : public IOpBuilder {

{"LRN", QNN_OP_LRN},

{"Pad", QNN_OP_PAD}};
{"Pad", QNN_OP_PAD},

{"Expand", QNN_OP_ELEMENT_WISE_MULTIPLY}};
auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type);
ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end());
return it->second;
@@ -6,7 +6,6 @@
#include <utility>

#include "core/providers/common.h"
#include "core/util/qmath.h"
#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
@@ -32,57 +31,6 @@ class BatchNormOpBuilder : public BaseOpBuilder {
const NodeUnit& node_unit,
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;

std::pair<float, float> CheckMinMax(float rmin, float rmax) const {
// Ensure a minimum range of 0.0001 (required by QNN)
rmax = std::max(rmax, rmin + 0.0001f);

// Both QNN and ORT require the range to include 0.0f
rmin = std::min(rmin, 0.0f);
rmax = std::max(rmax, 0.0f);

return std::make_pair(rmin, rmax);
}

template <typename T>
Status GetQminQmax(const Qnn_DataType_t qnn_data_type,
T& qmin,
T& qmax) const {
if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) {
qmin = static_cast<T>(std::numeric_limits<int8_t>::min());
qmax = static_cast<T>(std::numeric_limits<int8_t>::max());
} else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) {
qmin = static_cast<T>(std::numeric_limits<uint8_t>::min());
qmax = static_cast<T>(std::numeric_limits<uint8_t>::max());
} else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) {
qmin = static_cast<T>(std::numeric_limits<int16_t>::min());
qmax = static_cast<T>(std::numeric_limits<int16_t>::max());
} else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) {
qmin = static_cast<T>(std::numeric_limits<uint16_t>::min());
qmax = static_cast<T>(std::numeric_limits<uint16_t>::max());
} else {
ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type);
}
return Status::OK();
}

Status GetQuantParams(float rmin,
float rmax,
const Qnn_DataType_t qnn_data_type,
float& scale,
int& zero_point) const {
std::tie(rmin, rmax) = CheckMinMax(rmin, rmax);
float qmin = 0.0f;
float qmax = 255.0f;
ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax));

scale = (rmax - rmin) / (qmax - qmin);
const float initial_zero_point = qmin - (rmin / scale);
zero_point = static_cast<int>(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point)));
// To match QNN quantization definition
zero_point = 0 - zero_point;
return Status::OK();
}

inline Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type,
const uint8_t* raw_ptr,
double& value,
@@ -303,38 +251,6 @@ class BatchNormOpBuilder : public BaseOpBuilder {
return Status::OK();
}

inline double Dequantize(const OnnxInputInfo& info,
const double quant_value) const {
auto offset = static_cast<double>(info.quant_param.scaleOffsetEncoding.offset);
auto scale = static_cast<double>(info.quant_param.scaleOffsetEncoding.scale);
return (quant_value + offset) * scale;
}

template <typename T>
inline T Saturate(const T qmax,
const T qmin,
const T quant_value) const {
if (quant_value > qmax) {
return qmax;
} else if (quant_value < qmin) {
return qmin;
} else {
return quant_value;
}
}

inline Status Quantize(const double double_value,
const float scale,
const int zero_point,
const Qnn_DataType_t qnn_data_type,
int& quant_value) const {
int qmin = 0;
int qmax = 255;
ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax));
quant_value = Saturate(qmax, qmin, static_cast<int>(std::round((double_value / scale) - zero_point)));
return Status::OK();
}

Status PreprocessMean(const OnnxInputInfo& mean_info,
const bool is_npu_backend,
const uint8_t* mean_raw_ptr,
@@ -349,7 +265,10 @@ class BatchNormOpBuilder : public BaseOpBuilder {
for (; i < static_cast<int>(channel); ++i) {
double mean_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset));
mean_out[i] = (is_npu_backend) ? Dequantize(mean_info, mean_value) : mean_value;
mean_out[i] = (is_npu_backend) ? utils::Dequantize(mean_info.quant_param.scaleOffsetEncoding.offset,
mean_info.quant_param.scaleOffsetEncoding.scale,
mean_value)
: mean_value;
}
return Status::OK();
}
@@ -369,7 +288,10 @@ class BatchNormOpBuilder : public BaseOpBuilder {
for (; i < static_cast<int>(channel); ++i) {
double var_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset));
std_out[i] = (is_npu_backend) ? Dequantize(var_info, var_value) : var_value;
std_out[i] = (is_npu_backend) ? utils::Dequantize(var_info.quant_param.scaleOffsetEncoding.offset,
var_info.quant_param.scaleOffsetEncoding.scale,
var_value)
: var_value;
std_out[i] = std::sqrt(std_out[i] + static_cast<double>(epsilon));
}
return Status::OK();
@@ -392,7 +314,10 @@ class BatchNormOpBuilder : public BaseOpBuilder {
for (; i < static_cast<int>(channel); ++i) {
double scale_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset));
scale_out[i] = (is_npu_backend) ? Dequantize(scale_info, scale_value) : scale_value;
scale_out[i] = (is_npu_backend) ? utils::Dequantize(scale_info.quant_param.scaleOffsetEncoding.offset,
scale_info.quant_param.scaleOffsetEncoding.scale,
scale_value)
: scale_value;
scale_out[i] = scale_out[i] / std_double_tensor[i];
rmax = std::max(rmax, scale_out[i]);
rmin = std::min(rmin, scale_out[i]);
@@ -418,7 +343,10 @@ class BatchNormOpBuilder : public BaseOpBuilder {
for (; i < static_cast<int>(channel); ++i) {
double bias_value = 0.0;
ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset));
bias_out[i] = (is_npu_backend) ? Dequantize(bias_info, bias_value) : bias_value;
bias_out[i] = (is_npu_backend) ? utils::Dequantize(bias_info.quant_param.scaleOffsetEncoding.offset,
bias_info.quant_param.scaleOffsetEncoding.scale,
bias_value)
: bias_value;
bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]);
rmax = std::max(rmax, bias_out[i]);
rmin = std::min(rmin, bias_out[i]);
@@ -437,17 +365,17 @@ class BatchNormOpBuilder : public BaseOpBuilder {
raw_tensor.resize(double_tensor.size());
float scale = 0.0f;
int zero_point = 0;
ORT_RETURN_IF_ERROR(GetQuantParams(static_cast<float>(rmin),
static_cast<float>(rmax),
info.qnn_data_type,
scale,
zero_point));
ORT_RETURN_IF_ERROR(utils::GetQuantParams(static_cast<float>(rmin),
static_cast<float>(rmax),
info.qnn_data_type,
scale,
zero_point));
quant_param = QNN_QUANTIZE_PARAMS_INIT;
utils::InitializeQuantizeParam(quant_param, true, scale, zero_point);
for (size_t i = 0; i < double_tensor.size(); ++i) {
// onnx only supports 8 bits quantization
int quant_value_int = 0;
ORT_RETURN_IF_ERROR(Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int));
ORT_RETURN_IF_ERROR(utils::Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int));
if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) {
raw_tensor[i] = static_cast<uint8_t>(quant_value_int);
} else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) {
139 changes: 139 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc
@@ -0,0 +1,139 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/common/safeint.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace qnn {

class ExpandOpBuilder : public BaseOpBuilder {
public:
ExpandOpBuilder() : BaseOpBuilder("ExpandOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExpandOpBuilder);

protected:
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

template <typename T>
void FillShapeInputData(std::vector<uint8_t>& shape_data, int shape_size, T ini_value) {
  shape_data.resize(shape_size * sizeof(T));
  T* shape_data_typed = reinterpret_cast<T*>(shape_data.data());
  std::fill(shape_data_typed, shape_data_typed + shape_size, ini_value);
}

// Use ElementWiseMultiply to implement the data broadcast.
// Get the shape data and create an initializer input with value 1 and the same shape:
// output = input[0] * input[1]
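// For example, Expand(input with shape [2,1], target shape [2,3]) becomes
//   ElementWiseMultiply(input, ones([2,3]))
// and the multiply's implicit broadcasting produces the [2,3] output.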
Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const {
ORT_UNUSED_PARAMETER(do_op_validation);
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "Expand should has 2 inputs!");

ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

// Process shape input
const auto& input_name = inputs[1].node_arg.Name();
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);
ORT_RETURN_IF_NOT(is_initializer_input, "QNN doesn't support dynamic shape.");

std::vector<uint32_t> shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, shape), "Cannot get shape");
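// The ONNX Expand shape input is a 1-D int64 tensor, so shape[0] here is the
// rank of the broadcast target (i.e., the number of output dimensions).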
uint32_t shape_rank = shape[0];
std::vector<uint8_t> unpacked_tensor;
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
const int64_t* shape_data_int64 = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
std::vector<uint32_t> input_shape(shape_rank, 0);
std::transform(shape_data_int64, shape_data_int64 + shape_rank, input_shape.begin(),
[](int64_t item) { return SafeInt<uint32_t>(item); });
int shape_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<uint32_t>());
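// shape_size is the element count of the broadcast target; the constant
// ones tensor created below holds this many elements.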

std::vector<uint8_t> shape_data;
bool is_quantized_tensor = inputs[0].quant_param.has_value();
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
const auto* type_proto = inputs[0].node_arg.TypeAsProto();
Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;
if (is_quantized_tensor) {
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(true, type_proto, qnn_data_type));
float scale = 0.0f;
int zero_point = 0;
float rmax = 1.0f;
float rmin = 1.0f;
ORT_RETURN_IF_ERROR(utils::GetQuantParams(rmin,
rmax,
qnn_data_type,
scale,
zero_point));
utils::InitializeQuantizeParam(quantize_param, true, scale, zero_point);
int quant_value_int = 0;
double ini_value = 1.0;
ORT_RETURN_IF_ERROR(utils::Quantize(ini_value, scale, zero_point, qnn_data_type, quant_value_int));
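// Worked example, assuming utils::GetQuantParams/utils::Quantize match the
// helpers removed from batch_norm_op_builder.cc above: for
// QNN_DATATYPE_UFIXED_POINT_8 the range [1.0, 1.0] is first widened to include
// 0.0, giving roughly [0.0, 1.0001], so scale ~= 1.0001f / 255 ~= 0.00392f and
// zero_point == 0; Quantize(1.0) then yields 255, i.e. the ones tensor is
// stored as the maximum code value.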
switch (qnn_data_type) {
case QNN_DATATYPE_SFIXED_POINT_8: {
FillShapeInputData(shape_data, shape_size, static_cast<int8_t>(quant_value_int));
break;
}
case QNN_DATATYPE_UFIXED_POINT_8: {
FillShapeInputData(shape_data, shape_size, static_cast<uint8_t>(quant_value_int));
break;
}
case QNN_DATATYPE_UFIXED_POINT_16: {
FillShapeInputData(shape_data, shape_size, static_cast<uint16_t>(quant_value_int));
break;
}
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported.");
} // switch
} else {
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type));
switch (qnn_data_type) {
case QNN_DATATYPE_FLOAT_32: {
FillShapeInputData(shape_data, shape_size, static_cast<float>(1.0));
break;
}
case QNN_DATATYPE_INT_32: {
FillShapeInputData(shape_data, shape_size, static_cast<int32_t>(1));
break;
}
case QNN_DATATYPE_UINT_32: {
FillShapeInputData(shape_data, shape_size, static_cast<uint32_t>(1));
break;
}
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported.");
} // switch
} // if-else

const std::string& output_name = node_unit.Outputs()[0].node_arg.Name();
std::string shape_input_name(input_name + "_" + output_name);
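// The ones data is added as a static (constant) QNN tensor; it becomes the
// second input of the ElementWiseMultiply that implements the broadcast.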
QnnTensorWrapper input_tensorwrapper(shape_input_name, QNN_TENSOR_TYPE_STATIC, qnn_data_type, quantize_param,
std::move(input_shape), std::move(shape_data));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");

input_names.push_back(shape_input_name);

return Status::OK();
}

void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.AddOpBuilder(op_type, std::make_unique<ExpandOpBuilder>());
}

} // namespace qnn
} // namespace onnxruntime
@@ -37,7 +37,6 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const {
ORT_UNUSED_PARAMETER(do_op_validation);
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "Gather should has 2 inputs at least!");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));