[QNN EP] Handle rank 3 InstanceNormalization with N != 1 #17897

Merged 5 commits on Oct 13, 2023
Changes from 2 commits
@@ -24,6 +24,12 @@ class InstanceNormOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;

protected:
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const override ORT_MUST_USE_RESULT;

Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
@@ -81,6 +87,73 @@ Status InstanceNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
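// ONNX InstanceNormalization has exactly three inputs: input, scale, and bias (B).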
assert(inputs.size() == 3);

OnnxInputInfo input0_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[0], input0_info));

// HTP backend can only handle rank 3 inputs if the batch size is 1. If the batch size is not 1,
// QNN EP must reshape the input and output to (N, 1, W, C) and process the InstanceNorm as rank 4.
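// For example, an NHWC-domain input of shape (2, 5, 3) (N, W, C) is handled as (2, 1, 5, 3) (N, H, W, C).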
if (input0_info.shape.size() != 3 || input0_info.shape[0] == 1) {
return BaseOpBuilder::ProcessInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation);
}

//
// Input 0 is rank 3 with batch size != 1. Must reshape the input to rank 4.
//

{
const std::string& input0_name = inputs[0].node_arg.Name();
const std::string op_input0_name = input0_info.is_initializer ? input0_name
: input0_name + "_ort_qnn_ep_reshape";
input_names.push_back(op_input0_name);

std::vector<uint8_t> initializer_data;
if (input0_info.is_initializer) {
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input0_info.initializer_tensor, initializer_data));
}

assert(node_unit.Domain() == kMSInternalNHWCDomain);
std::vector<uint32_t> op_shape = {
input0_info.shape[0], // N
1, // Height == 1
input0_info.shape[1], // Width
input0_info.shape[2] // Channels
};

if (!input0_info.is_initializer) {
// Add a Reshape node to convert the rank 3 input to rank 4 (i.e., set the height dimension to 1).
// We don't need to do this for initializers, because the number of elements does not change. We can just
// modify the shape dimensions.
bool is_graph_input = qnn_model_wrapper.IsGraphInput(input0_name);
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input0_name,
op_input0_name,
input0_info.shape,
op_shape,
input0_info.qnn_data_type,
input0_info.quant_param,
do_op_validation,
is_graph_input));
}

Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, op_input0_name);
QnnTensorWrapper input_tensorwrapper(op_input0_name, tensor_type, input0_info.qnn_data_type, input0_info.quant_param,
std::move(op_shape), std::move(initializer_data));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
}

ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[1], logger, input_names)); // Scale
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[2], logger, input_names)); // Bias
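// Note that scale and bias are rank 1 tensors of shape (C,), so they are unaffected by the reshape
// and are processed normally.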

return Status::OK();
}

Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
@@ -100,11 +173,60 @@ Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_m
param_tensor_names.push_back(epsilon_param_wrapper.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(epsilon_param_wrapper));

ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
const auto& outputs = node_unit.Outputs();
assert(outputs.size() == 1);

OnnxInputInfo output_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(outputs[0], output_info));

// HTP backend can only handle rank 3 inputs/outputs if the batch size is 1. If the batch size is not 1,
// QNN EP must reshape the input and output to (N, 1, W, C) and process the InstanceNorm as rank 4.
if (output_info.shape.size() != 3 || output_info.shape[0] == 1) {
return ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}

//
// The output is meant to be rank 3 with batch size != 1. Must create a QNN InstanceNorm op with a rank 4 output
// that is then reshaped to rank 3 again.
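// The resulting QNN node sequence is: InstanceNorm -> rank 4 intermediate (N, 1, W, C) -> Reshape -> rank 3 output.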
//

const std::string& orig_output_name = outputs[0].node_arg.Name();
std::string op_output_name = orig_output_name + "_ort_qnn_ep_reshape";

assert(node_unit.Domain() == kMSInternalNHWCDomain);
std::vector<uint32_t> op_output_shape = {
output_info.shape[0], // N
1, // H == 1
output_info.shape[1], // W
output_info.shape[2], // C
};

QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type,
output_info.quant_param, std::vector<uint32_t>(op_output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit),
QNN_OP_PACKAGE_NAME_QTI_AISW,
GetQnnOpType(node_unit.OpType()),
std::move(input_names),
{op_output_name},
std::move(param_tensor_names)),
"Failed to add node.");

const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);

// Add Reshape to convert QNN InstanceNorm output back to rank 3 (as expected by the rest of the ONNX graph).
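// Note: the last two arguments are assumed to be AddReshapeNode's is_for_input/is_for_output flags;
// the reshape input is an internal QNN tensor here, while its output may be a graph output.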
ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(op_output_name,
orig_output_name,
op_output_shape,
output_info.shape,
output_info.qnn_data_type,
output_info.quant_param,
do_op_validation,
false,
is_graph_output));
return Status::OK();
}

93 changes: 82 additions & 11 deletions onnxruntime/test/providers/qnn/instance_norm_htp_test.cc
@@ -21,21 +21,26 @@ template <typename QuantType>
static GetTestQDQModelFn<QuantType> BuildQDQInstanceNormTestCase(const TestInputDef<float>& input_def,
const TestInputDef<float>& scale_def,
const TestInputDef<float>& bias_def,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs) {
return [input_def, scale_def, bias_def, attrs](ModelTestBuilder& builder,
std::vector<QuantParams<QuantType>>& output_qparams) {
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
bool use_contrib_qdq = false) {
return [input_def, scale_def, bias_def, attrs,
use_contrib_qdq](ModelTestBuilder& builder,
std::vector<QuantParams<QuantType>>& output_qparams) {
// input => Q => DQ =>
NodeArg* input = MakeTestInput(builder, input_def);
QuantParams<QuantType> input_qparams = GetTestInputQuantParams<QuantType>(input_def);
NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point);
NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point,
use_contrib_qdq);

// scale => Q => DQ =>
NodeArg* scale = MakeTestInput(builder, scale_def);
QuantParams<QuantType> scale_qparams = GetTestInputQuantParams<QuantType>(scale_def);
NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point);
NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point,
use_contrib_qdq);

// bias (as int32) => DQ =>
NodeArg* bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale);
NodeArg* bias_qdq = MakeTestQDQBiasInput(builder, bias_def, input_qparams.scale * scale_qparams.scale,
use_contrib_qdq);

// InstanceNormalization operator.
auto* instance_norm_output = builder.MakeIntermediate();
@@ -46,7 +51,8 @@ static GetTestQDQModelFn<QuantType> BuildQDQInstanceNormTestCase(const TestInput
}

// Add instance_norm_output -> Q -> output_u8
AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, instance_norm_output, output_qparams[0].scale, output_qparams[0].zero_point);
AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, instance_norm_output, output_qparams[0].scale,
output_qparams[0].zero_point, use_contrib_qdq);
};
}

@@ -65,7 +71,8 @@ static void RunInstanceNormQDQTest(const TestInputDef<float>& input_def,
const TestInputDef<float>& scale_def,
const TestInputDef<float>& bias_def,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
ExpectedEPNodeAssignment expected_ep_assignment) {
ExpectedEPNodeAssignment expected_ep_assignment,
bool use_contrib_qdq = false) {
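// When use_contrib_qdq is true, the test model uses the com.microsoft Q/DQ contrib operators,
// which are needed for 16-bit quantization types.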
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
@@ -75,11 +82,10 @@ static void RunInstanceNormQDQTest(const TestInputDef<float>& input_def,

// Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs.
TestQDQModelAccuracy(BuildOpTestCase<float>("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs),
BuildQDQInstanceNormTestCase<QuantType>(input_def, scale_def, bias_def, attrs),
BuildQDQInstanceNormTestCase<QuantType>(input_def, scale_def, bias_def, attrs, use_contrib_qdq),
provider_options,
18,
expected_ep_assignment,
1e-5f);
expected_ep_assignment);
}

// Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit.
@@ -97,6 +103,19 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8) {
ExpectedEPNodeAssignment::All);
}

TEST_F(QnnHTPBackendTests, InstanceNormU16) {
std::vector<float> input_data = {3.21289f, -5.9981f, -1.72799f, 6.27263f, 3.36205f, -1.93515f, -5.40113f, 3.75648f, 6.15357f,
-5.25769f, 2.73637f, -0.901382f, -6.55612f, 1.99497f, -4.79228f, 2.69813f, 8.3064f, 0.0362501f};
std::vector<float> scale_data = {-0.148738f, -1.45158f};
std::vector<float> bias_data = {-2.2785083772f, 2.3338717017f};
RunInstanceNormQDQTest<uint16_t>(TestInputDef<float>({1, 2, 3, 3}, false, input_data).OverrideValueRange(-10.0f, 10.0f),
TestInputDef<float>({2}, true, scale_data).OverrideValueRange(-2.0f, 2.0f),
TestInputDef<float>({2}, true, bias_data).OverrideValueRange(-3.0f, 3.0f),
{},
ExpectedEPNodeAssignment::All,
true); // Use contrib Q/DQ ops for 16-bit support.
}

// Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit.
// Use an input of rank 3.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) {
@@ -107,6 +126,58 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) {
ExpectedEPNodeAssignment::All);
}

// Test 8-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
// which requires wrapping the QNN InstanceNorm op with reshapes.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1) {
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
RunInstanceNormQDQTest(TestInputDef<float>({2, 2, 3}, false, input_data),
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
TestInputDef<float>({2}, true, {1.0f, 3.0f}),
{},
ExpectedEPNodeAssignment::All);
}

// Test 16-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
// which requires wrapping the QNN InstanceNorm op with reshapes.
TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1) {
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
RunInstanceNormQDQTest<uint16_t>(TestInputDef<float>({2, 2, 3}, false, input_data),
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
TestInputDef<float>({2}, true, {1.0f, 3.0f}),
{},
ExpectedEPNodeAssignment::All,
true); // Use contrib Q/DQ ops for 16-bit support.
}

// Test 8-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
// which requires wrapping the QNN InstanceNorm op with reshapes.
// Input 0 is an initializer.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1_Initializer) {
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
RunInstanceNormQDQTest(TestInputDef<float>({2, 2, 3}, true, input_data),
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
TestInputDef<float>({2}, false, {1.0f, 3.0f}),
{},
ExpectedEPNodeAssignment::All);
}

// Test 16-bit QDQ InstanceNormalization with an input of rank 3 with N != 1,
// which requires wrapping the QNN InstanceNorm op with reshapes.
// Input 0 is an initializer.
TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1_Initializer) {
std::vector<float> input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f,
-8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f};
RunInstanceNormQDQTest<uint16_t>(TestInputDef<float>({2, 2, 3}, true, input_data),
TestInputDef<float>({2}, true, {1.0f, 2.0f}),
TestInputDef<float>({2}, false, {1.0f, 3.0f}),
{},
ExpectedEPNodeAssignment::All,
true); // Use contrib Q/DQ ops for 16-bit support.
}

// Check that QNN InstanceNorm operator does not handle inputs with rank > 4.
TEST_F(QnnHTPBackendTests, InstanceNormU8Rank5) {
RunInstanceNormQDQTest(TestInputDef<float>({1, 2, 3, 3, 3}, false, -10.0f, 10.0f),