Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QNN EP] Add support for GatherElements #15966

Merged
merged 21 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bdbf9f8
Add support for GatherElements
adrianlizarraga May 11, 2023
23f90f2
Add tests for GatherElements
adrianlizarraga May 11, 2023
2d9be4a
Add support for GatherND and corresponding unit tests
adrianlizarraga May 16, 2023
e29f309
Merge latest commits from main
adrianlizarraga May 16, 2023
daf2905
Revert use of QNN operator macros until we can add more unit tests
adrianlizarraga May 16, 2023
6af2b56
Wrap long lines; improve zero point for types other than uint8_t
adrianlizarraga May 16, 2023
617c73e
Reveal another bug with GatherElements via a unit test
adrianlizarraga May 17, 2023
c9d1b37
Merge main and fix conflicts just to see where we stand. Need to upda…
adrianlizarraga Sep 22, 2023
0bd42bd
Merge with latest main branch
adrianlizarraga Aug 9, 2024
21b21dd
Update code to support negative static indices
adrianlizarraga Aug 9, 2024
3728b74
Remove GatherND traces
adrianlizarraga Aug 10, 2024
e183bd3
Update tests
adrianlizarraga Aug 12, 2024
1d327fd
Remove ORT_UNUSED_PARAM
adrianlizarraga Aug 12, 2024
d5264c6
Merge branch 'main' into adrianl/qnn-support-gatherelems-gathernd
adrianlizarraga Aug 12, 2024
20be096
Support QDQ GatherElements in quantization tool
adrianlizarraga Aug 12, 2024
32d0583
Disable ONNX test with negative gather elem indices for QNN
adrianlizarraga Aug 12, 2024
adad6d7
Add unit test for inaccurate GatherElems configuration with 2M indices
adrianlizarraga Aug 12, 2024
e934d02
Add py quantization tool unittest for QDQ GatherElements
adrianlizarraga Aug 12, 2024
0cc6285
fix warning as error
adrianlizarraga Aug 12, 2024
dbd1fb9
Merge branch 'main' into adrianl/qnn-support-gatherelems-gathernd
adrianlizarraga Aug 13, 2024
19ea8e1
Flip ternary values
adrianlizarraga Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops
/* static methods to return different operator's OpVersionMap */
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
return {{"Gather", {}},
{"GatherElements", {}},
{"GatherND", {}},
{"Reshape", {}},
{"Flatten", {}},
{"Transpose", {}},
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ OpBuilderRegistrations::OpBuilderRegistrations() {

{
CreateGatherOpBuilder("Gather", *this);
CreateGatherOpBuilder("GatherElements", *this);
CreateGatherOpBuilder("GatherND", *this);
}

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Status BaseOpBuilder::AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

bool BaseOpBuilder::OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized) const {
bool OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized) {
const std::unordered_map<int32_t, Qnn_DataType_t> onnx_to_qnn_data_type = {
{ONNX_NAMESPACE::TensorProto_DataType_INT8, QNN_DATATYPE_INT_8},
{ONNX_NAMESPACE::TensorProto_DataType_INT16, QNN_DATATYPE_INT_16},
Expand Down Expand Up @@ -312,7 +312,7 @@ Status BaseOpBuilder::ProcessAxisAttribute(const QnnModelWrapper& qnn_model_wrap
return Status::OK();
}

Qnn_TensorType_t BaseOpBuilder::GetInputTensorType(const QnnModelWrapper& qnn_model_wrapper, const std::string& input_name) const {
Qnn_TensorType_t GetInputTensorType(const QnnModelWrapper& qnn_model_wrapper, const std::string& input_name) {
if (qnn_model_wrapper.IsInitializerInput(input_name)) {
return QNN_TENSOR_TYPE_STATIC;
} else if (qnn_model_wrapper.IsGraphInput(input_name)) {
Expand Down
37 changes: 20 additions & 17 deletions onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,6 @@ class BaseOpBuilder : public IOpBuilder {
bool is_quantized_model,
std::vector<std::string>& input_names) const ORT_MUST_USE_RESULT;

bool OnnxDataTypeToQnnDataType(const int32_t data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized = false) const;

Status GetQnnDataType(const bool is_quantized_node, const ONNX_NAMESPACE::TypeProto* type_proto,
Qnn_DataType_t& tensor_data_type) const {
if (!type_proto || !type_proto->tensor_type().has_elem_type()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The tensor doesn't have elem_type.");
}

int32_t onnx_data_type = type_proto->tensor_type().elem_type();
ORT_RETURN_IF_NOT(OnnxDataTypeToQnnDataType(onnx_data_type, tensor_data_type, is_quantized_node),
"Failed to map Onnx data type to Qnn data type!");

return Status::OK();
}

adrianlizarraga marked this conversation as resolved.
Show resolved Hide resolved
const std::string& GetNodeName(const NodeUnit& node_unit) const {
const std::string& node_name(node_unit.Name());
if (node_name.empty()) {
Expand All @@ -108,7 +93,9 @@ class BaseOpBuilder : public IOpBuilder {
{"Equal", "ElementWiseEqual"},
{"Exp", "ElementWiseExp"},
{"Floor", "ElementWiseFloor"},
{"Gather", "Gather"},
{"Gather", QNN_OP_GATHER},
{"GatherElements", QNN_OP_GATHER_ELEMENTS},
{"GatherND", QNN_OP_GATHER_ND},
{"Greater", "ElementWiseGreater"},
{"GreaterOrEqual", "ElementWiseGreaterEqual"},
{"Less", "ElementWiseLess"},
Expand Down Expand Up @@ -268,7 +255,6 @@ class BaseOpBuilder : public IOpBuilder {
const NodeUnit& node_unit,
Qnn_Scalar_t& axis_qnn_scalar,
int32_t& default_axis_value) const;
Qnn_TensorType_t GetInputTensorType(const QnnModelWrapper& qnn_model_wrapper, const std::string& input_name) const;

size_t GetInputCountQnnRequired(const NodeUnit& node_unit) const {
auto input_output_cout = GetInputOutputCountQnnRequired(node_unit.OpType());
Expand Down Expand Up @@ -304,6 +290,23 @@ class BaseOpBuilder : public IOpBuilder {
const std::vector<size_t> cnhw2hwcn_perm{2, 3, 0, 1};
};

Qnn_TensorType_t GetInputTensorType(const QnnModelWrapper& qnn_model_wrapper, const std::string& input_name);

bool OnnxDataTypeToQnnDataType(const int32_t data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized = false);

inline Status GetQnnDataType(const bool is_quantized_node, const ONNX_NAMESPACE::TypeProto* type_proto,
Qnn_DataType_t& tensor_data_type) {
if (!type_proto || !type_proto->tensor_type().has_elem_type()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The tensor doesn't have elem_type.");
}

int32_t onnx_data_type = type_proto->tensor_type().elem_type();
ORT_RETURN_IF_NOT(OnnxDataTypeToQnnDataType(onnx_data_type, tensor_data_type, is_quantized_node),
"Failed to map Onnx data type to Qnn data type!");

return Status::OK();
}

// Type that holds information about an ONNX attribute.
template <typename ValType>
struct OnnxAttrInfo {
Expand Down
202 changes: 174 additions & 28 deletions onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
namespace onnxruntime {
namespace qnn {

// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes
class GatherOpBuilder : public BaseOpBuilder {
public:
GatherOpBuilder() : BaseOpBuilder("GatherOpBuilder") {}
Expand All @@ -35,31 +34,90 @@ class GatherOpBuilder : public BaseOpBuilder {
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model,
std::vector<std::string>& input_names,
bool do_op_validation) const {
ORT_UNUSED_PARAMETER(do_op_validation);
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "Gather should has 2 inputs at least!");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, is_quantized_model, input_names));
class GatherElementsOpBuilder : public BaseOpBuilder {
public:
GatherElementsOpBuilder() : BaseOpBuilder("GatherElementsOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GatherElementsOpBuilder);

protected:
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model,
std::vector<std::string>& input_names,
bool do_op_validation) const override ORT_MUST_USE_RESULT;

Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool is_quantized_model,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

class GatherNDOpBuilder : public BaseOpBuilder {
public:
GatherNDOpBuilder() : BaseOpBuilder("GatherNDOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GatherNDOpBuilder);

protected:
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model,
std::vector<std::string>& input_names,
bool do_op_validation) const override ORT_MUST_USE_RESULT;

Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool is_quantized_model,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

// Converts int64 indices to another integer type (typically int32 or uint32).
// The input and output are both represented as byte arrays.
template <typename T>
static void ConvertInt64IndicesBytes(const std::vector<uint8_t>& onnx_bytes, std::vector<uint8_t>& qnn_bytes) {
const size_t num_elems = onnx_bytes.size() / sizeof(uint64_t);
gsl::span<const uint64_t> onnx_indices{reinterpret_cast<const uint64_t*>(onnx_bytes.data()), num_elems};

// Process indices
const auto& input_name = inputs[1].node_arg.Name();
qnn_bytes.resize(num_elems * sizeof(T));
T* qnn_indices_ptr = reinterpret_cast<T*>(qnn_bytes.data());

std::transform(onnx_indices.begin(), onnx_indices.end(), qnn_indices_ptr,
[](int64_t index) { return SafeInt<T>(index); });
}

// Processes the indices input to Gather operators.
//
// Gather ops on the QNN CPU backend require int32 indices, so this function will either add a Cast operator
// to dynamic indices or transform static indices to int32/uint32.
//
// The HTP backend does not support int64, so this function returns an error status if dynamic indices are of
// type int64. If the indices are static, then this function will convert them to int32/uint32.
static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
const NodeUnitIODef& indices_input,
bool int32_type_is_signed,
const logging::Logger& logger,
bool is_quantized_model,
std::vector<std::string>& input_names,
bool do_op_validation) {
Qnn_DataType_t desired_data_type = int32_type_is_signed ? QNN_DATATYPE_INT_32 : QNN_DATATYPE_UINT_32;

const auto& input_name = indices_input.node_arg.Name();
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
input_names.push_back(input_name);
return Status::OK();
}

std::string indices_input_name(input_name);
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_INT_32;
const auto* type_proto = inputs[1].node_arg.TypeAsProto();
Qnn_DataType_t qnn_data_type = desired_data_type;
const auto* type_proto = indices_input.node_arg.TypeAsProto();
ORT_RETURN_IF_ERROR(GetQnnDataType(false, type_proto, qnn_data_type));

std::vector<uint8_t> unpacked_tensor;
std::vector<uint8_t> gather_indices;
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);

Expand All @@ -68,38 +126,39 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,

if (is_initializer_input) {
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
std::vector<uint8_t> unpacked_tensor;

ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));

if (qnn_data_type == QNN_DATATYPE_INT_64) {
// Convert initializer from int64 to int32
size_t size = unpacked_tensor.size() / sizeof(int64_t);
const int64_t* gather_indices_int64 = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
gather_indices.resize(size * sizeof(int32_t));
int32_t* gather_indices_int32 = reinterpret_cast<int32_t*>(gather_indices.data());
std::transform(gather_indices_int64, gather_indices_int64 + size, gather_indices_int32,
[](int64_t item) { return SafeInt<uint32_t>(item); });
if (desired_data_type == QNN_DATATYPE_INT_32) {
ConvertInt64IndicesBytes<int32_t>(unpacked_tensor, gather_indices);
} else {
ConvertInt64IndicesBytes<uint32_t>(unpacked_tensor, gather_indices);
}
} else {
gather_indices = std::move(unpacked_tensor);
adrianlizarraga marked this conversation as resolved.
Show resolved Hide resolved
}
qnn_data_type = QNN_DATATYPE_INT_32;
qnn_data_type = desired_data_type;
}

// Even for Quantized model, Gather indices use int32 without quantization
Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;

Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name);
std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, input_shape), "Cannot get shape");
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(indices_input.node_arg, input_shape), "Cannot get shape");
std::vector<uint32_t> cast_output_shape(input_shape);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param,
std::move(input_shape), std::move(gather_indices));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");

if (!is_initializer_input && qnn_data_type == QNN_DATATYPE_INT_64) {
// Insert cast node int64 -> int32
// Insert cast node int64 -> int32/uint32
if (qnn_data_type == QNN_DATATYPE_INT_64) {
// Add Cast node for indices
indices_input_name = input_name + "_cast";
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32, quantize_param,
QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, desired_data_type, quantize_param,
std::move(cast_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
Expand All @@ -118,6 +177,45 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "QNN EP: Gather operator must have two inputs");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, is_quantized_model, input_names));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], true, logger, is_quantized_model, input_names, do_op_validation);
adrianlizarraga marked this conversation as resolved.
Show resolved Hide resolved
}

Status GatherElementsOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "QNN EP: GatherElements operator must have two inputs");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, is_quantized_model, input_names));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], false, logger, is_quantized_model, input_names, do_op_validation);
}

Status GatherNDOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
ORT_RETURN_IF(inputs.size() != 2, "QNN EP: GatherND operator must have two inputs");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, is_quantized_model, input_names));

return ProcessIndicesInput(qnn_model_wrapper, inputs[1], false, logger, is_quantized_model, input_names, do_op_validation);
}

Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
Expand Down Expand Up @@ -212,8 +310,56 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
return Status::OK();
}

Status GatherElementsOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool is_quantized_model,
bool do_op_validation) const {
std::vector<std::string> param_tensor_names;
int32_t axis_value = 0;
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis_value));
QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), qnn_def::axis, axis_qnn_scalar);
param_tensor_names.push_back(axis_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(axis_param));

return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names),
logger, is_quantized_model, do_op_validation);
}

Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool is_quantized_model,
bool do_op_validation) const {
std::vector<std::string> param_tensor_names;
NodeAttrHelper node_attr_helper(node_unit);
int32_t onnx_batch_dims = node_attr_helper.Get("batch_dims", 0);

Qnn_Scalar_t qnn_batch_dims_scalar = QNN_SCALAR_INIT;
qnn_batch_dims_scalar.dataType = QNN_DATATYPE_UINT_32;
qnn_batch_dims_scalar.uint32Value = SafeInt<uint32_t>(onnx_batch_dims);

QnnParamWrapper batch_dims_param(node_unit.Index(), node_unit.Name(), QNN_OP_GATHER_ND_PARAM_BATCH_DIMS,
qnn_batch_dims_scalar);

param_tensor_names.push_back(batch_dims_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(batch_dims_param));

return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names),
logger, is_quantized_model, do_op_validation);
}

void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.AddOpBuilder(op_type, std::make_unique<GatherOpBuilder>());
if (op_type == "Gather") {
op_registrations.AddOpBuilder(op_type, std::make_unique<GatherOpBuilder>());
} else if (op_type == "GatherElements") {
op_registrations.AddOpBuilder(op_type, std::make_unique<GatherElementsOpBuilder>());
} else if (op_type == "GatherND") {
op_registrations.AddOpBuilder(op_type, std::make_unique<GatherNDOpBuilder>());
}
}

} // namespace qnn
Expand Down
Loading