From 1567536eeb7670478678512bcd77be44159a4088 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 12 Oct 2023 12:45:47 -0700 Subject: [PATCH] Address review comments --- .../qnn/builder/opbuilder/slice_op_builder.cc | 28 +++++++++++-------- .../core/providers/qnn/builder/qnn_utils.h | 8 ++++++ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc index db0c6f9517f69..90c82eeac7255 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc @@ -78,28 +78,29 @@ static Status GetInitializerInputData(const NodeUnitIODef& input, const QnnModel OnnxInputInfo input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(input, input_info)); ORT_RETURN_IF_NOT(input_info.is_initializer, - "QNN requires the starts, ends, axes, and steps inputs to " - "be initializers"); + "QNN requires the starts, ends, axes, and steps inputs to be initializers"); std::vector initializer_bytes; + + // Note: UnpackInitializerData() uses ORT's protobuf utilities, which ensure that the initializer bytes are + // contiguous, aligned, and in the appropriate endianness. This is necessary to be able to reinterpret bytes + // as an array of larger elements. ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, initializer_bytes)); - size_t tensor_byte_size = initializer_bytes.size(); const auto data_type = input_info.initializer_tensor->data_type(); - Status status; switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - const int64_t* tensor_data = reinterpret_cast(initializer_bytes.data()); - size_t size = tensor_byte_size / sizeof(int64_t); - output.insert(output.end(), tensor_data, tensor_data + size); + gsl::span elements = qnn::utils::ReinterpretBytesAsSpan(initializer_bytes.data(), + initializer_bytes.size()); + output.insert(output.end(), elements.begin(), elements.end()); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const int32_t* tensor_data = reinterpret_cast(initializer_bytes.data()); - size_t size = tensor_byte_size / sizeof(int32_t); - output.insert(output.end(), tensor_data, tensor_data + size); + gsl::span elements = qnn::utils::ReinterpretBytesAsSpan(initializer_bytes.data(), + initializer_bytes.size()); + output.insert(output.end(), elements.begin(), elements.end()); break; } default: @@ -174,9 +175,12 @@ Status SliceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr onnxruntime::SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); ORT_RETURN_IF_ERROR(SliceOp::PrepareForComputeHelper(raw_starts, raw_ends, raw_axes, raw_steps, compute_metadata)); - std::vector ranges_dims{static_cast(input_dimensions.size()), 3}; + const size_t input_rank = input_dimensions.size(); + std::vector ranges_dims{static_cast(input_rank), 3}; std::vector ranges_data; - for (size_t i = 0; i < input_dimensions.size(); i++) { + ranges_data.reserve(input_rank); + + for (size_t i = 0; i < input_rank; i++) { ranges_data.push_back(static_cast(compute_metadata.starts_[i])); ranges_data.push_back(static_cast(compute_metadata.ends_[i])); ranges_data.push_back(static_cast(compute_metadata.steps_[i])); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index a54e0c8276e71..f71d63ba7e6f6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -8,12 +8,20 @@ #include #include #include +#include namespace onnxruntime { namespace qnn { class QnnOpConfigWrapper; namespace utils { + +// Reinterprets an array of contiguous bytes in the target's endianness to a span of elements. +template +inline gsl::span ReinterpretBytesAsSpan(const uint8_t* data, size_t num_bytes) { + return gsl::span(reinterpret_cast(data), num_bytes / sizeof(T)); +} + size_t GetElementSizeByType(const Qnn_DataType_t& data_type); size_t GetElementSizeByType(ONNXTensorElementDataType elem_type);