[QNN EP] Fix a bug where a context binary can't be created if the model has inputs/outputs with different data types #18722

Merged: 3 commits, Dec 6, 2023
4 changes: 2 additions & 2 deletions docs/ContribOperators.md
@@ -1599,14 +1599,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Inputs (1 - ∞)

<dl>
-<dt><tt>inputs</tt> (variadic) : T</dt>
+<dt><tt>inputs</tt> (variadic, heterogeneous) : T</dt>
<dd>List of tensors for inputs</dd>
</dl>

#### Outputs (1 - ∞)

<dl>
-<dt><tt>outputs</tt> (variadic) : T</dt>
+<dt><tt>outputs</tt> (variadic, heterogeneous) : T</dt>
<dd>One or more outputs, list of tensors for outputs</dd>
</dl>

10 changes: 3 additions & 7 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3248,7 +3248,7 @@ void RegisterContribSchemas() {
"List of tensors for inputs",
"T",
OpSchema::Variadic,
-true,
+false,
1,
OpSchema::NonDifferentiable)
.Output(
@@ -3257,7 +3257,7 @@
"One or more outputs, list of tensors for outputs",
"T",
OpSchema::Variadic,
-true,
+false,
1,
OpSchema::NonDifferentiable)
.TypeConstraint(
@@ -3273,11 +3273,7 @@
"tensor(float16)",
"tensor(float)",
"tensor(double)"},
-"Constrain input and output types.")
-.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-// Type inference
-propagateElemTypeFromInputToOutput(ctx, 0, 0);
-});
+"Constrain input and output types.");

static const char* BitmaskDropout_ver1_doc = R"DOC(
BitmaskDropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar).
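In the contrib_defs.cc change above, the sixth argument to OpSchema::Input/Output is the is_homogeneous flag; flipping it from true to false marks the variadic parameter as heterogeneous, so each element of the list may bind the type variable T independently. The TypeAndShapeInferenceFunction, which propagated input 0's element type to output 0, is removed for the same reason: with mixed element types there is no single type to propagate. A minimal sketch of what the flag controls, assuming the standard ONNX OpSchema builder API; the op name, domain, and type list below are placeholders, not the schema touched by this PR:

```cpp
// Sketch only: declaring a variadic formal parameter as heterogeneous.
// When is_homogeneous is false, each element may resolve T to a different type.
#include "onnx/defs/schema.h"

ONNX_NAMESPACE::OpSchema MakeExampleSchema() {  // hypothetical helper
  ONNX_NAMESPACE::OpSchema schema;
  schema.SetName("ExampleOp")       // placeholder name
      .SetDomain("com.example")     // placeholder domain
      .SinceVersion(1)
      .Input(0, "inputs", "List of tensors for inputs", "T",
             ONNX_NAMESPACE::OpSchema::Variadic,
             /*is_homogeneous=*/false,  // the flag this PR flips
             /*min_arity=*/1)
      .Output(0, "outputs", "List of tensors for outputs", "T",
              ONNX_NAMESPACE::OpSchema::Variadic,
              /*is_homogeneous=*/false,
              /*min_arity=*/1)
      .TypeConstraint("T", {"tensor(float)", "tensor(int32)"},
                      "Constrain input and output types.");
  return schema;
}
```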
72 changes: 72 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -336,6 +336,78 @@
"high"); // qnn_context_priority
}

// Create a model with Cast + Add (quantized)
// cast_input -> Cast -> Q -> DQ \
// Add -> Q -> DQ -> output
// input2 -> Q -> DQ /
static GetTestModelFn BuildCastAddTestCase() {
return [](ModelTestBuilder& builder) {
// Create a Cast node: int32 -> float32
NodeArg* cast_input = MakeTestInput(builder, TestInputDef<int32_t>({2, 3}, false, {0, 1, 0, 1, 0, 1}));

auto* cast_output = builder.MakeIntermediate();
Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output});
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));

// Create Add node
std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
gsl::span<float> data_range = gsl::make_span(data);
QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
auto* add_input1_qdq = AddQDQNodePair<uint8_t>(builder, cast_output, q_parameter.scale, q_parameter.zero_point);

NodeArg* add_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, data));
auto* add_input2_qdq = AddQDQNodePair<uint8_t>(builder, add_input2, q_parameter.scale, q_parameter.zero_point);

auto* add_output = builder.MakeIntermediate();

builder.AddNode("Add", {add_input1_qdq, add_input2_qdq}, {add_output});

// add_output -> Q -> DQ -> output
AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add_output, q_parameter.scale, q_parameter.zero_point);
};
}

// Test that models with 2 inputs that have different data types can still generate the context binary
TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["qnn_context_cache_enable"] = "1";
const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx";
provider_options["qnn_context_cache_path"] = context_binary_file;

RunQnnModelTest(BuildCastAddTestCase(),
provider_options,
13, // opset
ExpectedEPNodeAssignment::All,
1e-5f,
logging::Severity::kERROR,
false);

// Make sure the Qnn context cache binary file is generated
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
}
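For context, the test above runs with verify_outputs=false (the new flag plumbed through the test utilities below) because of the still-open HTP accuracy issue covered by the next test; it only checks that the context binary file is produced. As a rough sketch of how an application might then reuse such a cached context binary, assuming the ORT C++ API and the same provider option keys used in the test; the model path "model.onnx" is hypothetical:

```cpp
// Sketch: creating a session that reuses the QNN context binary cache.
// Assumes an ORT build with the QNN EP; paths are illustrative.
#include <onnxruntime_cxx_api.h>

#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "qnn_ctx_demo");
  Ort::SessionOptions so;
  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "QnnHtp.dll"},  // libQnnHtp.so on Linux/Android
      {"qnn_context_cache_enable", "1"},
      {"qnn_context_cache_path", "./qnn_context_binary_int32_fp32_inputs_test.onnx"}};
  so.AppendExecutionProvider("QNN", qnn_options);
  // The first run compiles and writes the cache; later runs load the binary directly.
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  return 0;
}
```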

// A repro of QC case 06838696, an accuracy issue for Cast + Op (quantized):
// the value pair (1, 0.00392156886) at index #1 doesn't match,
// i.e. the actual value is off by -0.996078 from the expected 1.
TEST_F(QnnHTPBackendTests, DISABLED_CastAddHTPAccuracyTest) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif

RunQnnModelTest(BuildCastAddTestCase(),
provider_options,
13, // opset
ExpectedEPNodeAssignment::All);
}
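A note on the numbers in the comment above: 0.00392156886 is 1/255, i.e. one uint8 quantization step for data quantized over [0, 1], so the reported mismatch is the expected value 1 collapsing to a single step. A trivial check of that arithmetic, assuming uint8 QDQ with scale 1/255:

```cpp
// Sketch: the mismatch in the disabled accuracy test is one quantization step.
#include <cstdio>

int main() {
  const float scale = 1.0f / 255.0f;  // 0.00392156886..., uint8 step for [0, 1]
  const float expected = 1.0f;        // reference (CPU EP) value
  const float actual = scale;         // value the HTP backend reportedly produced
  std::printf("delta = %f\n", actual - expected);  // prints -0.996078
  return 0;
}
```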

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
#endif // !defined(ORT_MINIMAL_BUILD)

4 changes: 2 additions & 2 deletions onnxruntime/test/providers/qnn/qnn_test_utils.cc
@@ -59,7 +59,7 @@ void TryEnableQNNSaver(ProviderOptions& qnn_options) {

void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options,
int opset_version, ExpectedEPNodeAssignment expected_ep_assignment,
-float fp32_abs_err, logging::Severity log_severity) {
+float fp32_abs_err, logging::Severity log_severity, bool verify_outputs) {
EPVerificationParams verification_params;
verification_params.ep_node_assignment = expected_ep_assignment;
verification_params.fp32_abs_err = fp32_abs_err;
@@ -84,7 +84,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options,
TryEnableQNNSaver(provider_options);
RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID",
QnnExecutionProviderWithOptions(provider_options),
-helper.feeds_, verification_params);
+helper.feeds_, verification_params, {}, verify_outputs);
}

void InferenceModel(const std::string& model_data, const char* log_id,
4 changes: 3 additions & 1 deletion onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -574,7 +574,9 @@ inline GetTestQDQModelFn<QuantType> BuildQDQOpTestCase(const std::string& op_type,
*/
void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options,
int opset_version, ExpectedEPNodeAssignment expected_ep_assignment,
-float fp32_abs_err = 1e-5f, logging::Severity log_severity = logging::Severity::kERROR);
+float fp32_abs_err = 1e-5f,
+logging::Severity log_severity = logging::Severity::kERROR,
+bool verify_outputs = true);

enum class BackendSupport {
SUPPORT_UNKNOWN,
3 changes: 2 additions & 1 deletion onnxruntime/test/util/include/test_utils.h
@@ -69,7 +69,8 @@ void RunAndVerifyOutputsWithEP(ModelPathOrBytes model_path_or_bytes,
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds,
const EPVerificationParams& params = EPVerificationParams(),
-const std::function<void(SessionOptions&)>& session_options_updater = {});
+const std::function<void(SessionOptions&)>& session_options_updater = {},
+bool verify_outputs = true);

// Tests model loading only.
// This can be used to test EPs in builds where only loading (and not running) of a model is supported.
7 changes: 5 additions & 2 deletions onnxruntime/test/util/test_utils.cc
@@ -133,7 +133,8 @@ void RunAndVerifyOutputsWithEP(ModelPathOrBytes model_path_or_bytes, std::string
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds,
const EPVerificationParams& params,
-const std::function<void(SessionOptions&)>& session_options_updater) {
+const std::function<void(SessionOptions&)>& session_options_updater,
+bool verify_outputs) {
std::vector<std::byte> model_data_buffer{};
const auto model_data = GetModelBytes(model_path_or_bytes, model_data_buffer);

@@ -184,7 +185,9 @@
// Run with EP and verify the result
std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object2.Run(run_options, feeds, output_names, &fetches));
-VerifyOutputs(output_names, expected_fetches, fetches, params);
+if (verify_outputs) {
+  VerifyOutputs(output_names, expected_fetches, fetches, params);
+}

if (params.graph_verifier) {
(*params.graph_verifier)(graph2);
Expand Down