Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QNN EP] Add support for GatherElements #15966

Merged
merged 21 commits into from
Aug 19, 2024
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bdbf9f8
Add support for GatherElements
adrianlizarraga May 11, 2023
23f90f2
Add tests for GatherElements
adrianlizarraga May 11, 2023
2d9be4a
Add support for GatherND and corresponding unit tests
adrianlizarraga May 16, 2023
e29f309
Merge latest commits from main
adrianlizarraga May 16, 2023
daf2905
Revert use of QNN operator macros until we can add more unit tests
adrianlizarraga May 16, 2023
6af2b56
Wrap long lines; improve zero point for types other than uint8_t
adrianlizarraga May 16, 2023
617c73e
Reveal another bug with GatherElements via a unit test
adrianlizarraga May 17, 2023
c9d1b37
Merge main and fix conflicts just to see where we stand. Need to upda…
adrianlizarraga Sep 22, 2023
0bd42bd
Merge with latest main branch
adrianlizarraga Aug 9, 2024
21b21dd
Update code to support negative static indices
adrianlizarraga Aug 9, 2024
3728b74
Remove GatherND traces
adrianlizarraga Aug 10, 2024
e183bd3
Update tests
adrianlizarraga Aug 12, 2024
1d327fd
Remove ORT_UNUSED_PARAM
adrianlizarraga Aug 12, 2024
d5264c6
Merge branch 'main' into adrianl/qnn-support-gatherelems-gathernd
adrianlizarraga Aug 12, 2024
20be096
Support QDQ GatherElements in quantization tool
adrianlizarraga Aug 12, 2024
32d0583
Disable ONNX test with negative gather elem indices for QNN
adrianlizarraga Aug 12, 2024
adad6d7
Add unit test for inaccurate GatherElems configuration with 2M indices
adrianlizarraga Aug 12, 2024
e934d02
Add py quantization tool unittest for QDQ GatherElements
adrianlizarraga Aug 12, 2024
0cc6285
fix warning as error
adrianlizarraga Aug 12, 2024
dbd1fb9
Merge branch 'main' into adrianl/qnn-support-gatherelems-gathernd
adrianlizarraga Aug 13, 2024
19ea8e1
Flip ternary values
adrianlizarraga Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops
// output Q have the same scale and zero_point.
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
return {{"Gather", {}},
{"GatherElements", {}},
{"Reshape", {}},
{"Expand", {}},
{"Flatten", {}},
Original file line number Diff line number Diff line change
@@ -110,6 +110,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() {

{
CreateGatherOpBuilder("Gather", *this);
CreateGatherOpBuilder("GatherElements", *this);
}

{
Original file line number Diff line number Diff line change
@@ -122,6 +122,7 @@ class BaseOpBuilder : public IOpBuilder {
{"Exp", QNN_OP_ELEMENT_WISE_EXP},
{"Floor", QNN_OP_ELEMENT_WISE_FLOOR},
{"Gather", QNN_OP_GATHER},
{"GatherElements", QNN_OP_GATHER_ELEMENTS},
{"Greater", QNN_OP_ELEMENT_WISE_GREATER},
{"GreaterOrEqual", QNN_OP_ELEMENT_WISE_GREATER_EQUAL},
{"Less", QNN_OP_ELEMENT_WISE_LESS},
225 changes: 165 additions & 60 deletions onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cassert>
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
@@ -13,11 +14,16 @@
namespace onnxruntime {
namespace qnn {

// Handles Gather and GatherElements
class GatherOpBuilder : public BaseOpBuilder {
public:
GatherOpBuilder() : BaseOpBuilder("GatherOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GatherOpBuilder);

Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

protected:
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
@@ -32,100 +38,199 @@ class GatherOpBuilder : public BaseOpBuilder {
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
Status GatherOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "Gather should has 2 inputs at least!");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));
const logging::Logger& logger) const {
// On QNN CPU backend, the QNN validator does not properly reject unsupported input shapes.
// This causes a Qnn graph execution error. So, reject those configs here.
// We should consider not using QNN CPU backend for onnxruntime unit tests.
const std::string& op_type = node_unit.OpType();
if (qnn_model_wrapper.GetQnnBackendType() == QnnBackendType::CPU && op_type == "GatherElements") {
const auto& input0 = node_unit.Inputs()[0];
std::vector<uint32_t> input0_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input0.node_arg, input0_shape),
"Cannot get input[0] shape for ", op_type, " node ", node_unit.Name());

const size_t input0_rank = input0_shape.size();
ORT_RETURN_IF_NOT(input0_rank > 1 && input0_rank <= 4,
"QNN CPU backend does not support ", op_type, " with input[0] of rank ", input0_rank);
}

return BaseOpBuilder::IsOpSupported(qnn_model_wrapper, node_unit, logger);
}

// Makes negative indices positive and converts int64 indices to another integer type (typically int32 or uint32).
// The input and output are both represented as byte arrays.
template <typename SrcType, typename DstType>
static bool FixStaticIndices(const std::vector<uint8_t>& onnx_bytes,
int64_t input0_axis_dim,
/*out*/ std::vector<uint8_t>& qnn_bytes) {
const size_t num_elems = onnx_bytes.size() / sizeof(SrcType);
gsl::span<const SrcType> onnx_indices{reinterpret_cast<const SrcType*>(onnx_bytes.data()), num_elems};

qnn_bytes.resize(num_elems * sizeof(DstType));
gsl::span<DstType> qnn_indices{reinterpret_cast<DstType*>(qnn_bytes.data()), num_elems};

for (size_t i = 0; i < num_elems; i++) {
SrcType onnx_index = onnx_indices[i];

// Try to make a negative index positive by adding rank.
if (onnx_index < 0) {
onnx_index += static_cast<SrcType>(input0_axis_dim);
}

if (onnx_index < 0 || static_cast<int64_t>(onnx_index) >= input0_axis_dim) {
return false; // QNN does not support out-of-bounds indices.
}

qnn_indices[i] = static_cast<DstType>(onnx_index);
}

return true;
}

// Gets the size of input0 on the axis dimension.
static Status GetInpu0AxisDimValue(const QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
int64_t default_axis_value,
/*out*/ int64_t& axis_dim_value) {
const auto& input0 = node_unit.Inputs()[0];
std::vector<uint32_t> input0_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input0.node_arg, input0_shape),
"Cannot get shape for ", node_unit.OpType(), " input[0] ", input0.node_arg.Name());

int64_t rank = static_cast<int64_t>(input0_shape.size());
NodeAttrHelper node_helper(node_unit);
int64_t onnx_axis = node_helper.Get("axis", default_axis_value);
if (onnx_axis < 0) {
onnx_axis += rank;
}
ORT_RETURN_IF_NOT((onnx_axis >= 0 && onnx_axis < static_cast<int64_t>(input0_shape.size())),
"QNN requires axis range [0, rank-1] for ", node_unit.OpType());

axis_dim_value = static_cast<int64_t>(input0_shape[onnx_axis]);

// Process indices
const auto& input_name = inputs[1].node_arg.Name();
return Status::OK();
}

// Processes the indices input to Gather operators.
//
// In general, QNN only supports int32/uint32 indices. QNN EP has to add Cast for dynamic int64 indices or
// convert static int64 indices to int32.
//
// The HTP backend only supports dynamic int64 indices if they are a graph input.
static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
const NodeUnitIODef& indices_input,
int64_t input0_axis_dim,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) {
const auto& input_name = indices_input.node_arg.Name();
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
input_names.push_back(input_name);
return Status::OK();
}

std::string indices_input_name(input_name);
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_INT_32;
const auto* type_proto = inputs[1].node_arg.TypeAsProto();
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type));

std::vector<uint8_t> unpacked_tensor;
std::vector<uint8_t> gather_indices;
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);

// Gather input 0 is quantized tensor, input 1 (indices) is int64, this is not supported by QNN
bool is_quantized_tensor = inputs[0].quant_param.has_value();
ORT_RETURN_IF(is_quantized_tensor && qnn_data_type == QNN_DATATYPE_INT_64 && !is_initializer_input,
"HTP backend doesn't support any int64 data type.");

if (is_initializer_input) {
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
if (qnn_data_type == QNN_DATATYPE_INT_64) {
// Convert initializer from int64 to int32
size_t size = unpacked_tensor.size() / sizeof(int64_t);
const int64_t* gather_indices_int64 = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
gather_indices.resize(size * sizeof(int32_t));
int32_t* gather_indices_int32 = reinterpret_cast<int32_t*>(gather_indices.data());
std::transform(gather_indices_int64, gather_indices_int64 + size, gather_indices_int32,
[](int64_t item) { return SafeInt<uint32_t>(item); });
TensorInfo indices_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(indices_input, indices_info));

const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType());
const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name);
ORT_RETURN_IF(is_npu_backend &&
(indices_info.qnn_data_type == QNN_DATATYPE_INT_64) &&
!(indices_info.is_initializer || is_graph_input),
"HTP backend doesn't support a Gather* op with a dynamic int64 input activation ",
"unless it is a graph input.");

std::vector<uint8_t> qnn_indices_bytes;

// Get raw bytes for static indices.
// If indices are int64, convert them to int32 and update indices_info.qnn_data_type.
if (indices_info.is_initializer) {
std::vector<uint8_t> onnx_indices_bytes;
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*indices_info.initializer_tensor, onnx_indices_bytes));

if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
ORT_RETURN_IF_NOT((FixStaticIndices<int64_t, int32_t>(onnx_indices_bytes, input0_axis_dim, qnn_indices_bytes)),
"QNN does not support negative index values for Gather* ops");
indices_info.qnn_data_type = QNN_DATATYPE_INT_32;
} else if (indices_info.qnn_data_type == QNN_DATATYPE_INT_32) {
ORT_RETURN_IF_NOT((FixStaticIndices<int32_t, int32_t>(onnx_indices_bytes, input0_axis_dim, qnn_indices_bytes)),
"QNN does not support negative index values for Gather* ops");
} else {
gather_indices = std::move(unpacked_tensor);
qnn_indices_bytes = std::move(onnx_indices_bytes);
}
qnn_data_type = QNN_DATATYPE_INT_32;
}

Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(input_name);
std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, input_shape), "Cannot get shape");
std::vector<uint32_t> cast_output_shape(input_shape);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, QnnQuantParamsWrapper(),
std::move(input_shape), std::move(gather_indices));
std::vector<uint32_t> cast_output_shape(indices_info.shape);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, indices_info.qnn_data_type, QnnQuantParamsWrapper(),
std::move(indices_info.shape), std::move(qnn_indices_bytes));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");

if (!is_initializer_input && qnn_data_type == QNN_DATATYPE_INT_64) {
// Insert cast node int64 -> int32
if (qnn_data_type == QNN_DATATYPE_INT_64) {
// Add Cast node for indices
indices_input_name = input_name + "_ort_qnn_ep_cast";
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
QnnQuantParamsWrapper(), std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
{input_name},
{indices_input_name},
{},
do_op_validation),
"Failed to add node.");
}
// Insert QNN Cast op to convert dynamic indices from int64 to int32.
std::string indices_input_name(input_name);
if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
assert(!indices_info.is_initializer);

indices_input_name = input_name + "_ort_qnn_ep_cast";
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32,
QnnQuantParamsWrapper(), std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
"Cast",
{input_name},
{indices_input_name},
{},
do_op_validation),
"Failed to add node.");
}

input_names.push_back(indices_input_name);

return Status::OK();
}

Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "QNN EP: ", node_unit.OpType(), " operator must have two inputs");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names));

int64_t input0_axis_dim = 0;
ORT_RETURN_IF_ERROR(GetInpu0AxisDimValue(qnn_model_wrapper, node_unit, /*default_axis*/ 0, input0_axis_dim));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], input0_axis_dim, logger, input_names, do_op_validation);
}

Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const {
ORT_UNUSED_PARAMETER(logger);
const bool is_gather_elems = node_unit.OpType() == "GatherElements";

// Create QNN 'axis' parameter.
std::vector<std::string> param_tensor_names;
int32_t axis_value = 0;
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis_value));
QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_GATHER_PARAM_AXIS, axis_qnn_scalar);
QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(),
(is_gather_elems ? QNN_OP_GATHER_ELEMENTS_PARAM_AXIS : QNN_OP_GATHER_PARAM_AXIS),
axis_qnn_scalar);
param_tensor_names.push_back(axis_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(axis_param));

if (is_gather_elems) {
return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}

// if indicies is scalar shape, then need to add Reshape node
const auto& input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[0]);
const auto& indices_input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[1]);
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/quantization/operators/gather.py
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@ def __init__(self, onnx_quantizer, onnx_node):

def quantize(self):
node = self.node
assert node.op_type == "Gather"
assert node.op_type == "Gather" or node.op_type == "GatherElements"

if self.quantizer.is_valid_quantize_weight(node.input[0]) or self.quantizer.force_quantize_no_input_check:
self.quantizer.quantize_activation_tensor(node.input[0])
1 change: 1 addition & 0 deletions onnxruntime/python/tools/quantization/registry.py
Original file line number Diff line number Diff line change
@@ -79,6 +79,7 @@
"MatMul": QDQMatMul,
"Split": QDQSplit,
"Gather": QDQGather,
"GatherElements": QDQGather,
"Where": QDQWhere,
"InstanceNormalization": QDQNormalization,
"LayerNormalization": QDQNormalization,
3 changes: 2 additions & 1 deletion onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
@@ -827,7 +827,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
ORT_TSTR("sce_NCd1d2d3_sum_weight_high_ii_expanded"),
ORT_TSTR("sce_none_weights_log_prob_expanded"),
ORT_TSTR("sce_none_weights_expanded"),
ORT_TSTR("convtranspose_3d")};
ORT_TSTR("convtranspose_3d"),
ORT_TSTR("gather_elements_negative_indices")};

std::unordered_set<std::basic_string<ORTCHAR_T>> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests));

Original file line number Diff line number Diff line change
@@ -42,9 +42,15 @@ void GetData(const std::vector<int64_t>& input_dims, const std::vector<int64_t>&
output_data.resize(output_size);
std::srand(static_cast<unsigned>(std::time(0)));
for (size_t i = 0; i < indices_size; ++i) {
#if defined(USE_QNN)
// Negative index not possible.
indices_data[i] =
static_cast<TIndex>(static_cast<int64_t>(std::rand()) % input_dims[axis]);
#else
// Negative index possible.
indices_data[i] =
static_cast<TIndex>((static_cast<int64_t>(std::rand()) % (input_dims[axis] * 2)) - input_dims[axis]);
#endif
}
for (size_t i = 0; i < output_size; ++i) {
int64_t input_offset = 0;
@@ -382,9 +388,10 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) {
// skip cuda as the cuda kernel won't throw the error message
// skip openvino which will not throw error message but will ensure no out-of-bound access
// skip TensorRT because it doesn't support out of bounds indices
// skip QNN because it doesn't support out of bounds indices
test.Run(OpTester::ExpectResult::kExpectFailure, "",
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider,
kTensorrtExecutionProvider, kDmlExecutionProvider});
kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider});
}

TEST(GatherElementsOpTest, BigIndices) {
Loading