Commit

[QNN EP] Fix index-out-of-bounds bug in Slice builder when initializer is shared (#17905)

### Description
An index-out-of-bounds bug is triggered when a Slice operator shares an initializer input (e.g., starts or ends) with another operator that QNN EP processes first. In that case, the Slice builder never populates the `raw_starts`
(or `raw_ends`) vector, which is later indexed by the call to
`SliceOp::PrepareForComputeHelper()`.
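
For illustration, here is a minimal, self-contained C++ sketch of the failure mode (this is not ONNX Runtime code; the names `CollectStartsBuggy`, `CollectStartsFixed`, and `already_registered` are made up for the example). The old builder skipped any input whose QNN tensor had already been added by another node, so a shared `starts` initializer was never copied into `raw_starts`; the fix reads the initializer data regardless of whether the tensor was already registered.

```cpp
#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Initializer = std::pair<std::string, std::vector<int64_t>>;

// Old flow: inputs whose QNN tensor was already registered by another op are
// skipped entirely, so their data never reaches `starts`.
std::vector<int64_t> CollectStartsBuggy(const std::vector<Initializer>& inputs,
                                        const std::set<std::string>& already_registered) {
  std::vector<int64_t> starts;
  for (const auto& input : inputs) {
    if (already_registered.count(input.first) != 0) {
      continue;  // BUG: the data is never copied when the initializer is shared.
    }
    starts.assign(input.second.begin(), input.second.end());
  }
  return starts;
}

// Fixed flow: always read the initializer's data, independent of whether the
// tensor was already added to the QNN graph by a previous node.
std::vector<int64_t> CollectStartsFixed(const std::vector<Initializer>& inputs) {
  std::vector<int64_t> starts;
  for (const auto& input : inputs) {
    starts.assign(input.second.begin(), input.second.end());
  }
  return starts;
}

int main() {
  const std::vector<Initializer> starts_inputs = {{"starts", {1, 0}}};
  const std::set<std::string> already_registered = {"starts"};  // e.g., an Add node consumed "starts" first

  const auto buggy = CollectStartsBuggy(starts_inputs, already_registered);
  const auto fixed = CollectStartsFixed(starts_inputs);

  // With the buggy flow, `starts` stays empty, so later code that indexes
  // starts[axis] (as SliceOp::PrepareForComputeHelper effectively does) reads
  // out of bounds. The fixed flow yields the expected two elements.
  std::cout << "buggy starts size: " << buggy.size() << "\n";  // 0
  std::cout << "fixed starts size: " << fixed.size() << "\n";  // 2
  return 0;
}
```

In the actual change, this "always read the data" path is the new `GetInitializerInputData()` helper, and the Slice builder now adds only input 0 to the QNN graph; the starts/ends/axes/steps values are passed to QNN as static parameters instead.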


### Motivation and Context
Fixes a bug that blocks #17764.
adrianlizarraga authored Oct 13, 2023
1 parent 9c65d55 commit 3b69d9b
Showing 3 changed files with 176 additions and 109 deletions.
201 changes: 94 additions & 107 deletions onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc
@@ -8,6 +8,8 @@
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/cpu/tensor/slice_helper.h"

#include "core/framework/tensorprotoutils.h"

#include "base_op_builder.h"

namespace onnxruntime {
@@ -37,16 +39,13 @@ class SliceOpBuilder : public BaseOpBuilder {
TensorShapeVector& raw_starts,
TensorShapeVector& raw_ends,
TensorShapeVector& raw_axes) const;
typedef struct {
int32_t begin, end, stride;
} Range;
mutable std::vector<Range> ranges_;
};

Status SliceOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
size_t input_count = node_unit.Inputs().size();
// Op set 9 only has 1 input with starts, ends, axes attribute
// Op set > 9, starts, ends, axes are from node input

// Opset < 10: Only has 1 data input. The starts, ends, and axes values are attributes.
// Opset >= 10: Everything is an input. The data, starts, and ends inputs are required.
if (input_count > 1) {
// Skip the first input. All other input need to be initializer
for (size_t i = 1; i < input_count; i++) {
@@ -75,6 +74,46 @@ void SliceOpBuilder::GetDataFromAttribute(const NodeUnit& node_unit,
}
}

// Gets the data from initializer inputs (e.g., starts, ends, axes, or steps) as a TensorShapeVector.
static Status GetInitializerInputData(const NodeUnitIODef& input, const QnnModelWrapper& qnn_model_wrapper,
TensorShapeVector& output) {
const auto& input_name = input.node_arg.Name();
const bool is_initializer = qnn_model_wrapper.IsInitializerInput(input_name);
ORT_RETURN_IF_NOT(is_initializer, "Expected input ", input_name.c_str(), " to be an initializer.");
gsl::not_null<const ONNX_NAMESPACE::TensorProto*> initializer_proto = qnn_model_wrapper
.GetInitializerTensors()
.at(input_name);
ORT_RETURN_IF_NOT(initializer_proto->has_data_type(), "Expected initializer ", input_name.c_str(),
" to have a proto data type.");

// Create empty Tensor.
const auto* dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer_proto->data_type())->GetElementType();
TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(*initializer_proto);
Tensor tensor(dtype, shape, std::make_shared<CPUAllocator>());

// Deserialize initializer into Tensor.
onnxruntime::PathString model_path = qnn_model_wrapper.GetGraphViewer().ModelPath().ToPathString();
const ORTCHAR_T* model_path_str = model_path.empty() ? nullptr : model_path.c_str();
ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), model_path_str,
*initializer_proto, tensor));

Status status;

// Copy Tensor of int32_t or int64_t elems into output (int64_ts).
if (tensor.IsDataType<int64_t>()) {
gsl::span<const int64_t> tensor_elems = tensor.DataAsSpan<int64_t>();
output.insert(output.end(), tensor_elems.begin(), tensor_elems.end());
} else if (tensor.IsDataType<int32_t>()) {
gsl::span<const int32_t> tensor_elems = tensor.DataAsSpan<int32_t>();
output.insert(output.end(), tensor_elems.begin(), tensor_elems.end());
} else {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type ", DataTypeImpl::ToString(dtype),
" is not supported for Slice initializer input ", input.node_arg.Name().c_str());
}

return status;
}

// Note: For ONNX Slice operation the expected number of inputs is between 3 and 5
Status SliceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
@@ -84,123 +123,71 @@ Status SliceOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
if (do_op_validation) {
ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit));
}
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;

// Only need to add input 0. The other inputs (if any) contain static data that is passed to QNN APIs
// as static parameters.
return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names);
}

Status SliceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const {
// Extract starts, ends, axes, and steps data from attributes (opset < 10) or initializer inputs (opset >= 10).
TensorShapeVector raw_starts;
TensorShapeVector raw_ends;
TensorShapeVector raw_axes;
TensorShapeVector raw_steps;
std::vector<uint32_t> input0_shape;

auto inputs = node_unit.Inputs();
auto input_count = inputs.size();
// Opset 9, only 1 input, starts, ends, axes are in attribute
if (1 == input_count) {
const auto& inputs = node_unit.Inputs();
const size_t input_count = inputs.size();

// Opset 9 only has 1 input. The starts, ends, axes values are attributes.
if (node_unit.SinceVersion() < 10) {
GetDataFromAttribute(node_unit, raw_starts, raw_ends, raw_axes);
}
} else {
constexpr size_t starts_index = 1;
constexpr size_t ends_index = 2;
constexpr size_t axes_index = 3;
constexpr size_t steps_index = 4;

for (size_t input_i = 0; input_i < input_count; ++input_i) {
auto& input_name = inputs[input_i].node_arg.Name();
if (input_name.empty()) {
// Ignore unspecified/unused optional input
continue;
}
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added or the input is not named, skip it: " << input_name;
input_names.push_back(input_name);
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[input_i].node_arg, input0_shape), "Cannot get shape");
continue;
// Starts input (required).
ORT_RETURN_IF_ERROR(GetInitializerInputData(inputs[starts_index], qnn_model_wrapper, raw_starts));

// Ends input (required).
ORT_RETURN_IF_ERROR(GetInitializerInputData(inputs[ends_index], qnn_model_wrapper, raw_ends));

// Axes input (optional).
if (input_count > axes_index && !inputs[axes_index].node_arg.Name().empty()) {
ORT_RETURN_IF_ERROR(GetInitializerInputData(inputs[axes_index], qnn_model_wrapper, raw_axes));
}

bool is_quantized_tensor = inputs[input_i].quant_param.has_value();
const auto* type_proto = inputs[input_i].node_arg.TypeAsProto();
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type));

std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[input_i].node_arg, input_shape), "Cannot get shape");

Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;
utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor);
ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[input_i].quant_param,
quantize_param.scaleOffsetEncoding.scale,
quantize_param.scaleOffsetEncoding.offset),
"Cannot get quantization parameter");

std::vector<uint8_t> unpacked_tensor;
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);
if (is_initializer_input) {
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor));
size_t tensor_byte_size = unpacked_tensor.size();
const auto data_type = input_tensor->data_type();
TensorShapeVector data;
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
const int64_t* tensor_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
size_t size = tensor_byte_size / sizeof(int64_t);
data.insert(data.end(), tensor_data, tensor_data + size);
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
const int32_t* tensor_data = reinterpret_cast<const int32_t*>(unpacked_tensor.data());
size_t size = tensor_byte_size / sizeof(int32_t);
data.insert(data.end(), tensor_data, tensor_data + size);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Data type for starts and ends inputs' is not supported in this build. Got ",
data_type);
}
if (input_i == 0) {
// Do nothing!
} else if (input_i == 1) {
// Starts
raw_starts = data;
continue;
} else if (input_i == 2) {
// Ends
raw_ends = data;
continue;
} else if (input_i == 3) {
// Axes
raw_axes = data;
continue;
} else if (input_i == 4) {
// Steps
raw_steps = data;
continue;
}
// Steps input (optional).
if (input_count > steps_index && !inputs[steps_index].node_arg.Name().empty()) {
ORT_RETURN_IF_ERROR(GetInitializerInputData(inputs[steps_index], qnn_model_wrapper, raw_steps));
}
input0_shape = input_shape;

input_names.push_back(input_name);
Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name);
Qnn_QuantizeParams_t quantize_params = QNN_QUANTIZE_PARAMS_INIT;
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_params,
std::move(input_shape), std::move(unpacked_tensor));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
}

std::vector<uint32_t> input0_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input0_shape),
"Cannot get shape for Slice input 0.");

TensorShapeVector input_dimensions(input0_shape.cbegin(), input0_shape.cend());
onnxruntime::SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions);
ORT_RETURN_IF_ERROR(
SliceOp::PrepareForComputeHelper(raw_starts, raw_ends, raw_axes, raw_steps, compute_metadata));
ranges_.clear();
for (size_t i = 0; i < input_dimensions.size(); i++) {
auto start = static_cast<int32_t>(compute_metadata.starts_[i]);
auto end = static_cast<int32_t>(compute_metadata.ends_[i]);
auto step = static_cast<int32_t>(compute_metadata.steps_[i]);
ranges_.push_back(Range({start, end, step}));
}
return Status::OK();
}
ORT_RETURN_IF_ERROR(SliceOp::PrepareForComputeHelper(raw_starts, raw_ends, raw_axes, raw_steps, compute_metadata));

Status SliceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const {
std::vector<uint32_t> ranges_dims{static_cast<uint32_t>(ranges_.size()), 3};
const size_t input_rank = input_dimensions.size();
std::vector<uint32_t> ranges_dims{static_cast<uint32_t>(input_rank), 3};
std::vector<uint32_t> ranges_data;
for (auto range : ranges_) {
ranges_data.push_back(static_cast<uint32_t>(range.begin));
ranges_data.push_back(static_cast<uint32_t>(range.end));
ranges_data.push_back(static_cast<uint32_t>(range.stride));
ranges_data.reserve(input_rank);

for (size_t i = 0; i < input_rank; i++) {
ranges_data.push_back(static_cast<uint32_t>(compute_metadata.starts_[i]));
ranges_data.push_back(static_cast<uint32_t>(compute_metadata.ends_[i]));
ranges_data.push_back(static_cast<uint32_t>(compute_metadata.steps_[i]));
}

QnnParamWrapper ranges_paramwrapper(node_unit.Index(),
node_unit.Name(),
QNN_OP_STRIDED_SLICE_PARAM_RANGES,
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
@@ -181,6 +181,8 @@ class QnnModelWrapper {

QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }

const GraphViewer& GetGraphViewer() const { return graph_viewer_; }

private:
bool CreateQnnInputOutputTensors(const std::string& qnn_node_name,
const std::vector<std::string>& names,
82 changes: 80 additions & 2 deletions onnxruntime/test/providers/qnn/slice_htp_test.cc
@@ -14,6 +14,48 @@

namespace onnxruntime {
namespace test {

// Test for "index-out-of-bounds" bug that occurred when a Slice operator
// shared one of its initializer inputs with another op that was processed by QNN EP first.
TEST_F(QnnCPUBackendTests, Slice_SharedInitializersBugFix) {
// Model with an Add that processes a shared initializer before Slice is processed.
GetTestModelFn model_fn = [](ModelTestBuilder& builder) {
NodeArg* input0 = builder.MakeInput<int32_t>({2, 2}, {1, 2, 3, 4});

// Initializers
NodeArg* starts_input = builder.Make1DInitializer<int32_t>({1, 0}); // Shared by Add
NodeArg* ends_input = builder.Make1DInitializer<int32_t>({2, 2});
NodeArg* axes_input = builder.Make1DInitializer<int32_t>({0, 1});
NodeArg* steps_input = builder.Make1DInitializer<int32_t>({1, 1});

// Add input0 with a shared initializer.
NodeArg* add_output = builder.MakeIntermediate();
builder.AddNode("Add", {input0, starts_input}, {add_output});

// Cast Add's output to float.
NodeArg* cast_output = builder.MakeIntermediate();
Node& cast_node = builder.AddNode("Cast", {add_output}, {cast_output});
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));

// Slice Cast's output
NodeArg* slice0_out = builder.MakeOutput();
builder.AddNode("Slice", {cast_output, starts_input, ends_input, axes_input, steps_input}, {slice0_out});
};

ProviderOptions provider_options;

#if defined(_WIN32)
provider_options["backend_path"] = "QnnCpu.dll";
#else
provider_options["backend_path"] = "libQnnCpu.so";
#endif

RunQnnModelTest(model_fn,
provider_options,
13, // opset
ExpectedEPNodeAssignment::All);
}

#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

/**
@@ -26,14 +68,16 @@ namespace test {
* \param axes_def The axes input's definition.
* \param steps_def The steps input's definition.
* \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None).
* \param use_contrib_qdq Force Q/DQ ops to use the com.microsoft domain (enable 16-bit).
*/
template <typename QuantType = uint8_t>
static void RunSliceQDQTest(const TestInputDef<float>& data_def,
const TestInputDef<int64_t>& starts_def,
const TestInputDef<int64_t>& ends_def,
const TestInputDef<int64_t>& axes_def,
const TestInputDef<int64_t>& steps_def,
ExpectedEPNodeAssignment expected_ep_assignment) {
ExpectedEPNodeAssignment expected_ep_assignment,
bool use_contrib_qdq = false) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
Expand All @@ -45,7 +89,8 @@ static void RunSliceQDQTest(const TestInputDef<float>& data_def,
const std::vector<TestInputDef<int64_t>> int64_inputs = {starts_def, ends_def, axes_def, steps_def};

TestQDQModelAccuracy(BuildOpTestCase<float, int64_t>("Slice", f32_inputs, int64_inputs, {}),
BuildQDQOpTestCase<QuantType, int64_t>("Slice", f32_inputs, int64_inputs, {}),
BuildQDQOpTestCase<QuantType, int64_t>("Slice", f32_inputs, int64_inputs, {}, kOnnxDomain,
use_contrib_qdq),
provider_options,
18,
expected_ep_assignment);
@@ -123,6 +168,39 @@ TEST_F(QnnHTPBackendTests, SliceInt32OnHTP) {
ExpectedEPNodeAssignment::All);
}

// Test 8-bit QDQ Slice with more than 1 axis.
TEST_F(QnnHTPBackendTests, SliceU8_MultAxes) {
std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
RunSliceQDQTest<uint8_t>(TestInputDef<float>({2, 4}, false, input_data),
TestInputDef<int64_t>({2}, true, {1, 0}), // starts
TestInputDef<int64_t>({2}, true, {2, 3}), // ends
TestInputDef<int64_t>({2}, true, {0, 1}), // axes
TestInputDef<int64_t>({2}, true, {1, 2}), // steps
ExpectedEPNodeAssignment::All);
}

// Test 16-bit QDQ Slice with more than 1 axis.
TEST_F(QnnHTPBackendTests, SliceU16_MultAxes) {
std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
RunSliceQDQTest<uint16_t>(TestInputDef<float>({2, 4}, false, input_data),
TestInputDef<int64_t>({2}, true, {1, 0}), // starts
TestInputDef<int64_t>({2}, true, {2, 3}), // ends
TestInputDef<int64_t>({2}, true, {0, 1}), // axes
TestInputDef<int64_t>({2}, true, {1, 2}), // steps
ExpectedEPNodeAssignment::All,
true); // Use com.microsoft Q/DQ ops for 16-bit
}

// Test 8-bit QDQ Slice with more than 1 axis and an end value that exceeds the associated dimension size.
TEST_F(QnnHTPBackendTests, SliceU8_MultAxes_LargeEnd) {
std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
RunSliceQDQTest<uint8_t>(TestInputDef<float>({2, 4}, false, input_data),
TestInputDef<int64_t>({2}, true, {0, 1}), // starts
TestInputDef<int64_t>({2}, true, {-1, 1000}), // ends
TestInputDef<int64_t>({2}, true, {0, 1}), // axes
TestInputDef<int64_t>({2}, true, {1, 1}), // steps
ExpectedEPNodeAssignment::All);
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

} // namespace test
