From cc4049af831034cd7d0b78e4b055b8994939e62a Mon Sep 17 00:00:00 2001 From: glen-amd <146770157+glen-amd@users.noreply.github.com> Date: Fri, 19 Jul 2024 08:34:03 -0700 Subject: [PATCH 1/5] Enabled more VitisAI backend compilers (#21411) ### Description Enabled more VitisAI backend compilers --- onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc | 2 +- onnxruntime/core/providers/vitisai/include/ep_context_utils.h | 4 ++-- .../core/providers/vitisai/vitisai_execution_provider.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc index ab31aa313cf6d..368c8c0358228 100644 --- a/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc +++ b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc @@ -466,7 +466,7 @@ std::string RetrieveEPContextCache( fs::path ep_ctx_fs_path(ep_ctx_model_loc); // Attr "ep_cache_context" stores a relative path. ep_ctx_fs_path.replace_filename(fs::path(ep_ctx_cache)); - // TODO: Validaion of the file location to make sure security is met. + // TODO: Validation of the file location to make sure security is met. if (!fs::exists(ep_ctx_fs_path) || !fs::is_regular_file(ep_ctx_fs_path)) { ORT_THROW("File for EP context cache is missing"); } diff --git a/onnxruntime/core/providers/vitisai/include/ep_context_utils.h b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h index 61a595cf1ae15..26546f422765c 100644 --- a/onnxruntime/core/providers/vitisai/include/ep_context_utils.h +++ b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h @@ -14,8 +14,8 @@ namespace fs = std::filesystem; namespace onnxruntime { constexpr const uint8_t kXCCode = 1; -constexpr const uint8_t kDDCode = 2; -constexpr const uint8_t kVCode = 4; +[[maybe_unused]] constexpr const uint8_t kDDCode = 2; +[[maybe_unused]] constexpr const uint8_t kVCode = 4; static constexpr const char* kEPContextOp = "EPContext"; static constexpr const char* kMainContextAttr = "main_context"; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index f45b89649bfcb..036831df7a9cf 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -86,7 +86,7 @@ void VitisAIExecutionProvider::PrepareEPContextEnablement( model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string()); } std::string backend_cache_dir, backend_cache_key; - get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode, backend_cache_dir, backend_cache_key, backend_cache_data_); + get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode | kDDCode | kVCode, backend_cache_dir, backend_cache_key, backend_cache_data_); info_["cacheDir"] = backend_cache_dir; info_["cacheKey"] = backend_cache_key; // Create a new model, reusing the graph name, the op-domain-to-opset-version map, From 22d4d82f3c55525510bef785fab6c7c83a21c2e9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 19 Jul 2024 08:36:47 -0700 Subject: [PATCH 2/5] Move ReluQuantFusion to Level2 for CPU EP only (#21329) ### Description Moves the `Relu -> QuantizeLinear` fusion to Level2 optimizations for CPU EP only. 
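For illustration only (not part of this patch): because the fusion now runs at Level2, it still fires under the default `ORT_ENABLE_ALL` optimization level, but a session pinned to basic (Level1) optimizations will keep the standalone Relu node. A minimal Python sketch, assuming a placeholder QDQ model path:

```python
import onnxruntime as ort

# ReluQuantFusion is now a Level2 (extended) optimization for the CPU EP,
# so the session needs at least ORT_ENABLE_EXTENDED for the Relu -> Q fusion to apply.
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
sess = ort.InferenceSession("qdq_model.onnx", so, providers=["CPUExecutionProvider"])
```
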
### Motivation and Context See the related PR for motivation and context: https://github.com/microsoft/onnxruntime/pull/20627 --- .../core/optimizer/graph_transformer_utils.cc | 2 +- .../qdq_transformer/relu_quantizelinear.cc | 4 +- .../test/optimizer/qdq_transformer_test.cc | 51 +++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4298551aec412..e6feb3e7ddbe2 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -132,12 +132,12 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); - rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; case TransformerLevel::Level2: rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; diff --git a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc index 7417212c570c8..e756ffe78a289 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc @@ -13,13 +13,15 @@ namespace onnxruntime { bool ReluQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || + !graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { return false; } // if Relu is followed by QuantizeLinear, it can be fused into QuantizeLinear potentially const auto& next_node = *node.OutputNodesBegin(); - if (!QDQ::MatchQNode(next_node)) { + if (!graph_utils::IsSupportedProvider(next_node, {kCpuExecutionProvider}) || + !QDQ::MatchQNode(next_node)) { return false; } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1c77121ba9df1..1638851daf65a 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2763,6 +2763,57 @@ TEST(QDQTransformerTests, Clip) { } } +// Test that the ReluQuantFusion transformer only runs for optimization level >= 2. +TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) { + auto test_case = [&](TransformerLevel opt_level, int8_t zp) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 2, 2, 2}, + {-4, -3, -2, 0, 1, 2, 3, 4}); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, 1.0f, zp, dq_output); + + // add Relu + auto* relu_output = builder.MakeIntermediate(); + builder.AddNode("Relu", {dq_output}, {relu_output}); + + // add Q + DQ + auto* q_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(relu_output, 1.0f, zp, q_output); + builder.AddDequantizeLinearNode(q_output, 1.0f, zp, output_arg); + }; + + auto check_relu_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + // Only fuse relu into Q if level >= 2 and zero_point == -128 for int8. 
+ // Level1 graph: input -> DQ -> Relu -> Q -> DQ -> output + // Level2+ graph: input -> DQ -> output (QuantReluFusion + QDQFinalCleanupTransformer transformers) + const bool fuse_relu = (zp == -128) && + (opt_level == TransformerLevel::Level2 || opt_level == TransformerLevel::Level3); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], fuse_relu ? 0 : 1); + EXPECT_EQ(op_to_count["Relu"], fuse_relu ? 0 : 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], fuse_relu ? 1 : 2); + }; + + constexpr float epsilon = std::numeric_limits::epsilon(); + + TransformerTester(build_test_case, check_relu_graph, + TransformerLevel::Default, + opt_level, + 18, + epsilon, + epsilon); + }; + + test_case(TransformerLevel::Level1, -128); // Will not fuse Relu into QuantizeLinear due to level1 opt. + test_case(TransformerLevel::Level2, -128); // Will fuse Relu into QuantizeLinear. + test_case(TransformerLevel::Level3, -128); // Will fuse Relu into QuantizeLinear. + test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128 +} + TEST(QDQTransformerTests, Concat) { auto test_case = [&](const std::vector>& input_shapes, int64_t axis, From 01df8c787d872d5fffb3b48d70d9a0e3d6323e3c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 19 Jul 2024 11:11:30 -0700 Subject: [PATCH 3/5] [js/web] fix vulnerable version of dependencies (#21412) ### Description ``` # npm audit report socket.io 3.0.0 - 4.6.2 Severity: high socket.io has an unhandled 'error' event - https://github.com/advisories/GHSA-25hc-qcg6-38wj Depends on vulnerable versions of engine.io fix available via `npm audit fix` node_modules/socket.io ws 8.0.0 - 8.17.0 Severity: high ws affected by a DoS when handling a request with many HTTP headers - https://github.com/advisories/GHSA-3h5v-q93c-6h6q fix available via `npm audit fix` node_modules/ws engine.io 0.7.8 - 0.7.9 || 6.0.0 - 6.5.4 Depends on vulnerable versions of ws node_modules/engine.io socket.io-adapter 2.5.2 - 2.5.4 Depends on vulnerable versions of ws node_modules/socket.io-adapter 4 high severity vulnerabilities ``` --- js/web/package-lock.json | 126 ++++++++++++++++++++------------------- 1 file changed, 65 insertions(+), 61 deletions(-) diff --git a/js/web/package-lock.json b/js/web/package-lock.json index b802a4e8271a7..3cfc0457c6234 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -194,9 +194,9 @@ } }, "node_modules/@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "node_modules/@szmarczak/http-timer": { @@ -236,9 +236,9 @@ "dev": true }, "node_modules/@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "dependencies": 
{ "@types/node": "*" @@ -1086,9 +1086,9 @@ } }, "node_modules/engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.5.5", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.5.5.tgz", + "integrity": "sha512-C5Pn8Wk+1vKBoHghJODM63yk8MvrO9EWZUfkAt5HAqIgPE4/8FF0PEGHXtEd40l223+cE5ABWuPzm38PHFXfMA==", "dev": true, "dependencies": { "@types/cookie": "^0.4.1", @@ -1099,17 +1099,17 @@ "cookie": "~0.4.1", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - "ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, "node_modules/engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true, "engines": { "node": ">=10.0.0" @@ -3020,35 +3020,37 @@ } }, "node_modules/socket.io": { - "version": "4.6.1", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.1.tgz", - "integrity": "sha512-KMcaAi4l/8+xEjkRICl6ak8ySoxsYG+gG6/XfRCPJPQ/haCRIJBTL4wIl8YCsmtaBovcAXGLOShyVWQ/FG8GZA==", + "version": "4.7.5", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.7.5.tgz", + "integrity": "sha512-DmeAkF6cwM9jSfmp6Dr/5/mfMwb5Z5qRrSXLpo3Fq5SqyU8CMF15jIN4ZhfSwu35ksM1qmHZDQ/DK5XTccSTvA==", "dev": true, "dependencies": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.1", + "engine.io": "~6.5.2", "socket.io-adapter": "~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, "node_modules/socket.io-adapter": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", "dev": true, "dependencies": { - "ws": "~8.11.0" + "debug": "~4.3.4", + "ws": "~8.17.1" } }, "node_modules/socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "dependencies": { "@socket.io/component-emitter": "~3.1.0", @@ -3449,16 +3451,16 @@ "dev": true }, "node_modules/ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": 
"8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "engines": { "node": ">=10.0.0" }, "peerDependencies": { "bufferutil": "^4.0.1", - "utf-8-validate": "^5.0.2" + "utf-8-validate": ">=5.0.2" }, "peerDependenciesMeta": { "bufferutil": { @@ -3648,9 +3650,9 @@ "dev": true }, "@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "@szmarczak/http-timer": { @@ -3687,9 +3689,9 @@ "dev": true }, "@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "requires": { "@types/node": "*" @@ -4379,9 +4381,9 @@ } }, "engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.5.5", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.5.5.tgz", + "integrity": "sha512-C5Pn8Wk+1vKBoHghJODM63yk8MvrO9EWZUfkAt5HAqIgPE4/8FF0PEGHXtEd40l223+cE5ABWuPzm38PHFXfMA==", "dev": true, "requires": { "@types/cookie": "^0.4.1", @@ -4392,14 +4394,14 @@ "cookie": "~0.4.1", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - "ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" } }, "engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true }, "ent": { @@ -5862,32 +5864,34 @@ "dev": true }, "socket.io": { - "version": "4.6.1", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.1.tgz", - "integrity": "sha512-KMcaAi4l/8+xEjkRICl6ak8ySoxsYG+gG6/XfRCPJPQ/haCRIJBTL4wIl8YCsmtaBovcAXGLOShyVWQ/FG8GZA==", + "version": "4.7.5", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.7.5.tgz", + "integrity": "sha512-DmeAkF6cwM9jSfmp6Dr/5/mfMwb5Z5qRrSXLpo3Fq5SqyU8CMF15jIN4ZhfSwu35ksM1qmHZDQ/DK5XTccSTvA==", "dev": true, "requires": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.1", + "engine.io": "~6.5.2", "socket.io-adapter": "~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" } }, "socket.io-adapter": { - "version": "2.5.2", - "resolved": 
"https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", "dev": true, "requires": { - "ws": "~8.11.0" + "debug": "~4.3.4", + "ws": "~8.17.1" } }, "socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "requires": { "@socket.io/component-emitter": "~3.1.0", @@ -6179,9 +6183,9 @@ "dev": true }, "ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "requires": {} }, From 6ffaaebb60cd43cf7749e67a9bb54c3bd2cc4efd Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 19 Jul 2024 13:58:54 -0700 Subject: [PATCH 4/5] [CUDA] Attention kernel provider option (#21344) ### Description * Add a cuda provider option `sdpa_kernel` to choose which attention kernel to run for testing purpose. * Allow dump which attention kernel is used per node. * Reserve a flag for cudnn flash attention which will be added soon. #### CUDA provider option sdpa_kernel Instead of setting environment variable, we also support setting it in provider option. Note that the setting is global per session. That could help performance testing of each kernel. #### Attention Kernel Debug Info Set an environment variable `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`, and ORT will print sdpa kernel used in each node: For example ``` ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 ./onnxruntime_test_all --gtest_filter=MultiHeadAttentionTest* ``` It will show debug information of kernel used in testing: ``` [ RUN ] MultiHeadAttentionTest.SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=0 TRT_FUSED_ATTENTION=1 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=1 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1 Operator=MultiHeadAttention Node=node1 DataType=fp16 TRT_FUSED_ATTENTION=1 AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=1 TRT_FUSED_ATTENTION=0 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=0 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1 Operator=MultiHeadAttention Node=node1 DataType=fp16 EFFICIENT_ATTENTION=1 ``` In this test case, the debug info shows that one session uses trt fused attention and another session use efficient attention. 
--- cmake/onnxruntime_rocm_hipify.cmake | 2 + cmake/onnxruntime_unittests.cmake | 3 +- .../providers/cuda/cuda_provider_options.h | 1 + .../contrib_ops/cpu/bert/attention_common.h | 28 ++- .../contrib_ops/cuda/bert/attention.cc | 53 ++--- onnxruntime/contrib_ops/cuda/bert/attention.h | 4 +- .../cuda/bert/attention_kernel_options.cc | 166 +++++++++++++ .../cuda/bert/attention_kernel_options.h | 67 ++++++ .../cuda/bert/group_query_attention.cc | 30 +-- .../cuda/bert/group_query_attention.h | 2 + .../cuda/bert/multihead_attention.cc | 50 ++-- .../cuda/bert/multihead_attention.h | 3 +- .../contrib_ops/cuda/bert/packed_attention.cc | 33 ++- .../contrib_ops/cuda/bert/packed_attention.h | 9 +- .../cuda/bert/packed_multihead_attention.cc | 40 ++-- .../cuda/bert/packed_multihead_attention.h | 4 +- .../providers/cuda/cuda_execution_provider.h | 17 ++ .../cuda/cuda_execution_provider_info.cc | 4 + .../cuda/cuda_execution_provider_info.h | 4 + onnxruntime/core/providers/cuda/cuda_kernel.h | 6 + .../providers/cuda/cuda_provider_factory.cc | 2 + .../multihead_attention_op_test.cc | 4 +- .../attention_kernel_options_test.cc | 221 ++++++++++++++++++ .../test/python/onnxruntime_test_python.py | 2 + 24 files changed, 645 insertions(+), 110 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 2966a4624a966..a8c876d30873e 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -15,6 +15,8 @@ set(contrib_ops_excluded_files "bert/attention_softmax.h" "bert/attention_softmax.cu" "bert/attention_prepare_qkv.cu" + "bert/attention_kernel_options.h" + "bert/attention_kernel_options.cc" "bert/decoder_attention_impl.h" "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0159c35d1941b..38ed0b1640192 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common) target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) - target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common) if (MSVC) # Cutlass code has an issue with the following: # warning C4100: 'magic': unreferenced formal parameter diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 6d53760ab60b5..01a14de699dc4 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 { 
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not int use_tf32 = 1; // use TF32 + int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index a5b9c84c63eb9..55292b35e1e38 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -147,6 +147,23 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_ } // namespace sparse_attention namespace attention { + +enum class AttentionBackend : int { + FLASH_ATTENTION = 1, + EFFICIENT_ATTENTION = 2, + TRT_FUSED_ATTENTION = 4, + CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention. + MATH = 16, // unfused kernel cannot be disabled right now. + + // The following kernels might be deprecated in the future. + TRT_FLASH_ATTENTION = 32, + TRT_CROSS_ATTENTION = 64, + TRT_CAUSAL_ATTENTION = 128, +}; + +// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled). +constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"; + // Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; @@ -157,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT // Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; +// Environment variable to enable or disable cuDNN flash attention. +constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION"; + // Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled). constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; @@ -166,11 +186,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF // Environment variable to enable or disable flash attention. Default is 0 (enabled). constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; -// Minimum sequence length to enable memory efficient attention in FP32. -constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256; +// Minimum sequence length to perfer memory efficient attention when data type is float32 +constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32"; + +// Default value for minimum sequence length to enable memory efficient attention in FP32. +constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256; // Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"; + // Default value for the above setting. 
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index d9907f09121d0..cacd65313ebcc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -3,7 +3,6 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention.h" #include "contrib_ops/cuda/bert/bert_padding.h" @@ -40,36 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { - disable_fused_self_attention_ = - sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + kernel_options_ = this->GetAttentionKernelOptions(); - enable_trt_flash_attention_ = - sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); - enable_fused_causal_attention_ = - sizeof(T) == 2 && - ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = - ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention(); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = - sizeof(T) != 2 || - onnxruntime::ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); + + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); } template @@ -134,7 +114,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads, parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. 
- if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } // Allocate buffers @@ -220,7 +200,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == past && nullptr == present && (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); if (use_memory_efficient_attention) { @@ -231,6 +211,20 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + + debug_info.Print("Attention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + cublasHandle_t cublas = GetCublasHandle(context); typedef typename ToCudaType::MappedType CudaT; @@ -268,7 +262,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_fused_cross_attention, use_memory_efficient_attention); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); - ; typedef typename ToCudaType::MappedType CudaT; AttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index acafb379d713f..0c7d3621f95ef 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -8,6 +8,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -27,9 +28,10 @@ class Attention final : public CudaKernel, public AttentionBase { bool enable_trt_flash_attention_; bool enable_fused_causal_attention_; bool disable_memory_efficient_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; + + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc new file mode 100644 index 0000000000000..28a095e68131e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#include +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" + +using namespace onnxruntime::contrib::attention; + +namespace onnxruntime { +void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { + if (value > 0) { + use_flash_attention_ = (value & static_cast(AttentionBackend::FLASH_ATTENTION)) > 0; + use_efficient_attention_ = (value & static_cast(AttentionBackend::EFFICIENT_ATTENTION)) > 0; + use_trt_fused_attention_ = (value & static_cast(AttentionBackend::TRT_FUSED_ATTENTION)) > 0; + use_cudnn_flash_attention_ = (value & static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0; + use_unfused_ = (value & static_cast(AttentionBackend::MATH)) > 0; + use_trt_flash_attention_ = (value & static_cast(AttentionBackend::TRT_FLASH_ATTENTION)) > 0; + use_trt_cross_attention_ = (value & static_cast(AttentionBackend::TRT_CROSS_ATTENTION)) > 0; + use_trt_causal_attention_ = (value & static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0; + } else { + use_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFlashAttention, false); + use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); + use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); + use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); + use_unfused_ = true; + use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableTrtFlashAttention, false); + use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedCrossAttention, false); + use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault(kEnableFusedCausalAttention, false); + } + + enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault(kEnableAttentionKernelDebugInfo, false); + + // When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing. + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + kMinSeqLenForFlashAttentionPackedQKV, + value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV); + + min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault( + kMinSeqLenForEfficientAttentionFp32, + value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32); + + if (use_build_flag) { + // Some kernels can be disabled at build time. If they are disabled, we should not use them. 
+#ifndef USE_FLASH_ATTENTION + use_flash_attention_ = false; +#endif + +#ifndef USE_MEMORY_EFFICIENT_ATTENTION + use_efficient_attention_ = false; +#endif + } +} + +void AttentionKernelOptions::InitializeOnce( + int sdpa_kernel, bool use_build_flag) { + std::call_once(this->initialize_once_flag_, [&]() { + this->Initialize(sdpa_kernel, use_build_flag); + if (this->enable_kernel_debug_info_) { + this->Print(); + } + }); +} + +void AttentionKernelOptions::Print() const { + std::stringstream sstream; + sstream << "AttentionKernelOptions:"; + sstream << " FLASH_ATTENTION=" << int(use_flash_attention_); + sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_); + sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_); + sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_); + sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_); + sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_); + sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_); + sstream << " MATH=" << int(use_unfused_); + + if (!use_unfused_) { + sstream << std::endl + << "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored."; + } + + // Output text in Cyan color to make it easier to spot + std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; +} + +// Classify the kernel used in TRT fused runner. +void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) { + if (causal) { + use_trt_causal_attention = true; + } else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) { + use_trt_flash_attention = true; + } else { + use_trt_fused_attention = true; + } +} + +void AttentionKernelDebugInfo::Print(const char* operator_name, + const std::string& node_name, + bool is_float16, + bool is_bfloat16) const { + std::stringstream sstream; + sstream << "Operator=" << operator_name; + + if (node_name.length() > 0) { + sstream << " Node=" << node_name; + } + + if (is_bfloat16) { + sstream << " DataType=bf16"; + } else if (is_float16) { + sstream << " DataType=fp16"; + } else { + sstream << " DataType=fp32"; + } + + if (use_flash_attention.has_value() && use_flash_attention.value()) { + sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value()); + } + + if (use_efficient_attention.has_value() && use_efficient_attention.value()) { + sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value()); + } + + if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) { + sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value()); + } + + if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) { + sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value()); + } + + if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) { + sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value()); + } + + if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) { + sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value()); + } + + if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) { + sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value()); + } + + bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) || + (use_efficient_attention.has_value() && use_efficient_attention.value()) || + 
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) || + (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) || + (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) || + (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) || + (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()); + + // Fall back to unfused when no fused kernel is enabled. + if (!use_fused) { + sstream << " MATH=1"; + } + + // Output text in Cyan color to make it easier to spot. + std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; +} + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h new file mode 100644 index 0000000000000..bd7df5f490c76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include + +namespace onnxruntime { +struct AttentionKernelDebugInfo { + std::optional use_flash_attention = std::nullopt; + std::optional use_efficient_attention = std::nullopt; + std::optional use_trt_fused_attention = std::nullopt; + std::optional use_cudnn_flash_attention = std::nullopt; + std::optional use_trt_flash_attention = std::nullopt; + std::optional use_trt_cross_attention = std::nullopt; + std::optional use_trt_causal_attention = std::nullopt; + void SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length); + void Print(const char* operator_name, const std::string& node_name, bool is_float16, bool is_bfloat16) const; +}; + +class AttentionKernelOptions { + public: + void InitializeOnce(int sdpa_kernel, bool use_build_flag); + + bool UseFlashAttention() const { return use_flash_attention_; } + bool UseEfficientAttention() const { return use_efficient_attention_; } + bool UseTrtFusedAttention() const { return use_trt_fused_attention_; } + bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; } + bool UseUnfusedAttention() const { return use_unfused_; } + bool UseTrtFlashAttention() const { return use_trt_flash_attention_; } + bool UseTrtCrossAttention() const { return use_trt_cross_attention_; } + bool UseTrtCausalAttention() const { return use_trt_causal_attention_; } + + bool AllowDebugInfo() const { return enable_kernel_debug_info_; } + + int MinSeqLenForFlashAttentionPackedQkv() const { return min_seq_len_for_flash_attention_packed_qkv_; } + int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; } + + protected: + void Print() const; + + void Initialize(int value, bool use_build_flag); + + private: + bool use_flash_attention_{true}; + bool use_efficient_attention_{true}; + bool use_trt_fused_attention_{true}; + bool use_cudnn_flash_attention_{false}; + bool use_unfused_{true}; + + bool use_trt_flash_attention_{true}; + bool use_trt_cross_attention_{true}; + + // Causal attention is disabled by default in #14732. 
+ bool use_trt_causal_attention_{false}; + + bool enable_kernel_debug_info_{false}; + + int min_seq_len_for_flash_attention_packed_qkv_{0}; + + int min_seq_len_for_efficient_attention_fp32_{0}; + + std::once_flag initialize_once_flag_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 3b6ad238cc826..797f9b0a1ea47 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -52,20 +52,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); -#else - disable_flash_attention_ = true; -#endif + kernel_options_ = this->GetAttentionKernelOptions(); + + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION // Memory efficient attention only supports float and float16, not bfloat16. - disable_memory_efficient_attention_ = std::is_same::value || - ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = std::is_same::value || !kernel_options_->UseEfficientAttention(); + if (!disable_flash_attention_) { zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); } @@ -161,7 +154,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size); if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -201,6 +194,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream()); #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + + debug_info.Print("GroupQueryAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + // seqlens_k buffer size_t seqlens_k_bytes = 0; seqlens_k_bytes = sizeof(int) * parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 15573ece166fc..4ff5b0a59f021 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -6,6 +6,7 @@ #include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -32,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel { bool 
disable_memory_efficient_attention_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ba8b00df07e06..b96140f3897f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/providers/cuda/cuda_common.h" -#include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" @@ -47,31 +46,16 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); - disable_fused_self_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + kernel_options_ = this->GetAttentionKernelOptions(); - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); - disable_fused_cross_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); + disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); // Allocate cache buffers constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast(kCumulatedSequenceLengthCacheMaxBatchSize) + 1); @@ -155,7 +139,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. 
if (use_flash_attention && key == nullptr && value == nullptr && - parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } // Allocate buffers @@ -229,9 +213,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } #if USE_MEMORY_EFFICIENT_ATTENTION + int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32(); bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 - parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 || - parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32; + parameters.sequence_length >= length_threshold || + parameters.kv_sequence_length >= length_threshold; bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; @@ -249,6 +234,21 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_fp16_runner_ != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + + debug_info.Print("MultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. 
bool no_qkv_workspace = nullptr == value && diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 86a32c92ce003..26e38dbad9fd7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -8,6 +8,7 @@ #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -31,12 +32,12 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 0146cce30c7d1..a1149ddbf99f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -33,12 +33,11 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) template -TrtFusedAttention::TrtFusedAttention() { - disable_fused_runner_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); - - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); +TrtFusedAttention::TrtFusedAttention(const OpKernelInfo& info) + : CudaKernel(info) { + kernel_options_ = this->GetAttentionKernelOptions(); + disable_fused_runner_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); } template @@ -86,7 +85,8 @@ template class TrtFusedAttention; template class TrtFusedAttention; template -PackedAttention::PackedAttention(const OpKernelInfo& info) : TrtFusedAttention(), CudaKernel(info) { +PackedAttention::PackedAttention(const OpKernelInfo& info) + : TrtFusedAttention(info) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); @@ -268,7 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; - parameters.use_tf32 = UseTF32(); + parameters.use_tf32 = this->UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -295,6 +295,19 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { } #endif + if (this->kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length); + } + + 
debug_info.Print("PackedAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + typedef typename ToCudaType::MappedType CudaT; CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -313,7 +326,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, this->UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias @@ -341,7 +354,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_memory_efficient_attention = use_memory_efficient_attention; - return QkvToContext(device_prop, cublas, Stream(context), parameters, data); + return QkvToContext(device_prop, cublas, this->Stream(context), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index f00c112fc73d2..67b420764169a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -9,6 +9,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -17,14 +18,16 @@ namespace cuda { using namespace onnxruntime::cuda; template -class TrtFusedAttention { +class TrtFusedAttention : public CudaKernel { public: - TrtFusedAttention(); + TrtFusedAttention(const OpKernelInfo& info); protected: MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; protected: + const AttentionKernelOptions* kernel_options_; + bool disable_fused_runner_; bool enable_trt_flash_attention_; mutable std::unique_ptr fused_fp16_runner_; @@ -32,7 +35,7 @@ class TrtFusedAttention { }; template -class PackedAttention final : public TrtFusedAttention, public CudaKernel { +class PackedAttention final : public TrtFusedAttention { public: PackedAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 3fbbafc01254e..53e96fc732a33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -35,30 +35,16 @@ REGISTER_KERNEL_TYPED(MLFloat16) template PackedMultiHeadAttention::PackedMultiHeadAttention(const OpKernelInfo& info) - : TrtFusedAttention(), CudaKernel(info) { + : TrtFusedAttention(info) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); scale_ = info.GetAttrOrDefault("scale", 0.0f); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || onnxruntime::ParseEnvironmentVariableWithDefault( - attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - 
attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_flash_attention_ = sizeof(T) != 2 || !this->kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = onnxruntime::ParseEnvironmentVariableWithDefault( - attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = !this->kernel_options_->UseEfficientAttention(); } template @@ -228,7 +214,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); PackedAttentionParameters parameters; - parameters.use_tf32 = UseTF32(); + parameters.use_tf32 = this->UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, @@ -255,7 +241,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. if (use_flash_attention && key == nullptr && value == nullptr && - parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + parameters.sequence_length < this->kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } } @@ -271,11 +257,25 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = is_good_for_rpb && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); } #endif + if (this->kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length); + } + + debug_info.Print("PackedMultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + typedef typename ToCudaType::MappedType CudaT; cublasHandle_t cublas = this->GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h index e30c603dc30aa..9b52a70fc6181 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h @@ -4,13 +4,14 @@ #pragma once #include "contrib_ops/cuda/bert/packed_attention.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { namespace cuda { template -class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaKernel { +class PackedMultiHeadAttention final : public TrtFusedAttention { public: PackedMultiHeadAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; @@ -32,7 +33,6 @@ class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaK bool 
disable_memory_efficient_attention_; bool disable_flash_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index f53779058a8af..9c8a8712ca51c 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -17,6 +17,10 @@ #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/tunable/cuda_tuning_context.h" +#ifndef DISABLE_CONTRIB_OPS +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#endif + namespace onnxruntime { void RunOnUnload(std::function function); @@ -80,6 +84,14 @@ class CUDAExecutionProvider : public IExecutionProvider { bool IsNHWCPreferred() const { return info_.prefer_nhwc; } bool UseTF32() const { return info_.use_tf32; } +#ifndef DISABLE_CONTRIB_OPS + // Attention kernel options parsed from sdpa_kernel cuda provider option. + const AttentionKernelOptions* GetAttentionKernelOptions() const { + attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true); + return &attention_kernel_options_; + } +#endif + ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); } @@ -110,6 +122,11 @@ class CUDAExecutionProvider : public IExecutionProvider { // the tuning context might be altered when calling into a TunableOp mutable cuda::tunable::CudaTuningContext tuning_context_; +#ifndef DISABLE_CONTRIB_OPS + // Attention kernel options parsed from sdpa_kernel cuda provider option. + mutable AttentionKernelOptions attention_kernel_options_; +#endif + class PerThreadContext final { public: PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index c96381e3e68b1..31cf991a34fc9 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s constexpr const char* kPreferNHWCMode = "prefer_nhwc"; constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; constexpr const char* kUseTF32 = "use_tf32"; +constexpr const char* kSdpaKernel = "sdpa_kernel"; } // namespace provider_option_names } // namespace cuda @@ -117,6 +118,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) + .AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -170,6 +172,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, 
MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; return options; @@ -192,6 +195,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 1cac3d1513698..0efad80f743df 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -79,6 +79,8 @@ struct CUDAExecutionProviderInfo { // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. bool use_tf32{true}; + int sdpa_kernel{0}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); @@ -91,6 +93,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { size_t value{0xbc9f1d34}; // seed // Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each) + // Do not exceed 32 bits here otherwise some bits will be lost in x86. size_t data = static_cast(info.device_id) ^ (static_cast(info.arena_extend_strategy) << 16) ^ (static_cast(info.cudnn_conv_algo_search) << 18) ^ @@ -109,6 +112,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { onnxruntime::HashCombine(info.gpu_mem_limit, value); onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + onnxruntime::HashCombine(info.sdpa_kernel, value); // Memory pointers onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 288da23f35ec8..9d37a9775872f 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -94,6 +94,12 @@ class CudaKernel : public OpKernel { return provider_->UseTF32(); } +#ifndef DISABLE_CONTRIB_OPS + const AttentionKernelOptions* GetAttentionKernelOptions() const { + return provider_->GetAttentionKernelOptions(); + } +#endif + tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 7851da7fa91a3..b1d54e56ded4e 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -226,6 +226,7 @@ struct CUDA_Provider : Provider { info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; info.use_tf32 = params->use_tf32 != 0; + info.sdpa_kernel = params->sdpa_kernel; return std::make_shared(info); } @@ -260,6 +261,7 @@ struct CUDA_Provider : Provider { cuda_options.prefer_nhwc = 
internal_options.prefer_nhwc; cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; cuda_options.use_tf32 = internal_options.use_tf32; + cuda_options.sdpa_kernel = internal_options.sdpa_kernel; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index a61e917b41e51..f0255d7ece84e 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -394,8 +394,8 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu } #if USE_MEMORY_EFFICIENT_ATTENTION - if (data.sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32 || - data.kv_sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32) { + if (data.sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32 || + data.kv_sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32) { kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( diff --git a/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc new file mode 100644 index 0000000000000..b2e986f680763 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef DISABLE_CONTRIB_OPS + +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "test/util/include/scoped_env_vars.h" +#include "gtest/gtest.h" + +#include +#include + +using onnxruntime::AttentionKernelOptions; +using onnxruntime::contrib::attention::AttentionBackend; + +namespace onnxruntime { +namespace test { + +TEST(AttentionKernelOptionsTest, NonZeroValue) { + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION) | static_cast(AttentionBackend::EFFICIENT_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_TRUE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_FUSED_ATTENTION) | static_cast(AttentionBackend::MATH); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_TRUE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + 
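Context for the `sdpa_kernel` plumbing above: the option is a plain integer whose bits select attention backends, and a value of 0 defers to the existing environment variables. The sketch below shows one way such a mask decodes into per-backend flags; the specific bit assignments are illustrative assumptions, with the authoritative values defined by the `AttentionBackend` enum in `attention_common.h`.

```cpp
// Minimal sketch of decoding an integer 'sdpa_kernel' value into per-backend flags.
// Bit values are assumptions for illustration only.
#include <cstdio>

enum Backend : int {
  kFlashAttention = 1 << 0,
  kEfficientAttention = 1 << 1,
  kTrtFusedAttention = 1 << 2,
  kCudnnFlashAttention = 1 << 3,
  kMath = 1 << 4,  // unfused reference kernel
};

struct KernelFlags {
  bool use_flash, use_efficient, use_trt_fused, use_cudnn_flash, use_unfused;
};

KernelFlags Decode(int sdpa_kernel) {
  // A value of 0 means "no explicit selection"; the EP then falls back to the
  // per-backend environment variables (kDisableFlashAttention, etc.).
  return KernelFlags{
      (sdpa_kernel & kFlashAttention) != 0,
      (sdpa_kernel & kEfficientAttention) != 0,
      (sdpa_kernel & kTrtFusedAttention) != 0,
      (sdpa_kernel & kCudnnFlashAttention) != 0,
      (sdpa_kernel & kMath) != 0,
  };
}

int main() {
  KernelFlags f = Decode(kFlashAttention | kEfficientAttention);
  std::printf("flash=%d efficient=%d trt_fused=%d cudnn=%d unfused=%d\n",
              f.use_flash, f.use_efficient, f.use_trt_fused, f.use_cudnn_flash, f.use_unfused);
  return 0;
}
```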
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_TRUE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_TRUE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_CROSS_ATTENTION) | static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_TRUE(options.UseTrtCrossAttention()); + ASSERT_TRUE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + // Test environment variables are ignored when option value is non-zero + // Test default min sequence lengths are zeros + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}}}; + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + // Test min sequence lengths can be parsed from 
environment variables when option value is non-zero + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"}, + {onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}}; + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256); + } +} + +// Test all environment variables take effect when option value is 0. +TEST(AttentionKernelOptionsTest, DefaultOptionWithEnvVar) { + constexpr int value = 0; + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"}, + {onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}}; + AttentionKernelOptions options; + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_TRUE(options.UseEfficientAttention()); + ASSERT_TRUE(options.UseTrtFusedAttention()); + ASSERT_TRUE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_TRUE(options.UseTrtFlashAttention()); + ASSERT_TRUE(options.UseTrtCrossAttention()); + ASSERT_TRUE(options.UseTrtCausalAttention()); + ASSERT_TRUE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256); +} + +// Test default min sequence lengths when environment variables are not set. 
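For completeness, a hypothetical end-to-end usage sketch of the new provider option via the ONNX Runtime C++ API. A CUDA-enabled build and a GPU are assumed; `model.onnx` and the option value are placeholders, and error handling is kept minimal.

```cpp
// Hypothetical usage of the "sdpa_kernel" CUDA provider option from the C++ API.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sdpa_kernel_demo");
  Ort::SessionOptions session_options;

  const OrtApi& api = Ort::GetApi();
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // "0" keeps the default behavior (environment variables apply);
  // a non-zero value selects backends explicitly.
  const char* keys[] = {"device_id", "sdpa_kernel"};
  const char* values[] = {"0", "1"};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 2));

  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  api.ReleaseCUDAProviderOptions(cuda_options);

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);  // placeholder model path
  return 0;
}
```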
+TEST(AttentionKernelOptionsTest, DefaultMinSeqLens) { + constexpr int value = 0; + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}}}; + AttentionKernelOptions options; + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), + onnxruntime::contrib::attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), + onnxruntime::contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32); +} + +} // namespace test +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e4814aa7fc033..892e7de8bb6ed 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -446,6 +446,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("use_tf32", ["1", "0"]) + test_get_and_set_option_with_values("sdpa_kernel", ["0", "1", "2"]) + option["gpu_external_alloc"] = "0" option["gpu_external_free"] = "0" option["gpu_external_empty_cache"] = "0" From 34cd2e8ed8688c42d00adda1d260ac787f76bf29 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 20 Jul 2024 09:35:05 +1000 Subject: [PATCH 5/5] Add CoreML ML Program Resize (#21370) ### Description Add CoreML ML Program Resize - refactor existing logic to try and simplify and share between NeuralNetwork and MLProgram checks - add handling for some new attributes - antialias and axes - should have been done when setting the CoreML EP max opset to 21 ### Motivation and Context Support priority models --- .../core/providers/coreml/builders/helper.cc | 18 +- .../core/providers/coreml/builders/helper.h | 3 +- .../coreml/builders/impl/base_op_builder.cc | 2 +- .../coreml/builders/impl/base_op_builder.h | 6 + .../coreml/builders/impl/resize_op_builder.cc | 607 +++++++++++++----- .../providers/coreml/builders/model_builder.h | 13 +- .../coreml/coreml_execution_provider.cc | 1 + .../builders/impl/resize_op_builder.cc | 25 +- onnxruntime/core/providers/utils.cc | 16 + onnxruntime/core/providers/utils.h | 5 + .../core/providers/xnnpack/tensor/resize.cc | 21 +- .../providers/cpu/tensor/resize_op_test.cc | 152 ++++- .../apple/coreml_supported_mlprogram_ops.md | 1 + 13 files changed, 671 insertions(+), 199 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index b8ebbd05a2a20..e1f148fa93e23 100644 --- 
a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -50,8 +50,8 @@ bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, } } -bool IsInputSupported(const Node& node, const NodeArg& input, - const OpBuilderInputParams& input_params, const logging::Logger& logger) { +bool IsInputSupported(const Node& node, const NodeArg& input, const OpBuilderInputParams& input_params, + const logging::Logger& logger, bool allow_empty_input) { if (!input.Exists()) { // optional input that is not provided return true; @@ -84,16 +84,10 @@ bool IsInputSupported(const Node& node, const NodeArg& input, return false; } - if (dim == 0) { - if (node.OpType() == "Resize" && &input == node.InputDefs()[1]) { - // one special case. Resize 'roi' input was originally a required input but is rarely used. - // ROI is not supported in the CoreML implementation so we will ignore the value, but is often added - // (at least in the unit tests) as an initializer with shape {0}. - } else { - LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name - << ", shape: " << Shape2String(shape); - return false; - } + if (dim == 0 && !allow_empty_input) { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index 300de2dedd122..0acaa0dd8a4a3 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -30,7 +30,8 @@ OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, const IOpBuilder* GetOpBuilder(const Node& node); bool IsInputSupported(const Node& node, const NodeArg& node_arg, const OpBuilderInputParams& input_params, - const logging::Logger& logger); + const logging::Logger& logger, + bool allow_empty_input = false); bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 83a572f4b60fa..2cae85a0a1c8d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -74,7 +74,7 @@ bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { for (const auto* input : node.InputDefs()) { - if (!IsInputSupported(node, *input, input_params, logger)) { + if (!IsInputSupported(node, *input, input_params, logger, allow_empty_tensor_as_input_)) { return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 4a23640d0f34c..071008520fbdc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -28,6 +28,10 @@ class BaseOpBuilder : public IOpBuilder { void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: + explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false) + : 
allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { + } + // currently we only support float static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, const logging::Logger& logger); @@ -50,6 +54,8 @@ class BaseOpBuilder : public IOpBuilder { virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const = 0; + + const bool allow_empty_tensor_as_input_; // some operators can handle ignoring an empty tensor as input }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 3400f09b4056f..65b5c17f2c6a6 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -1,13 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include +#include #include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" +#include "core/providers/utils.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,6 +20,11 @@ namespace onnxruntime { namespace coreml { class ResizeOpBuilder : public BaseOpBuilder { + public: + // allow roi and scales potentially being empty inputs that are ignored during processing + ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {} + + private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -29,196 +36,382 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing // We only support Resize opset 11+ here int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } + + bool SupportsMLProgram() const override { return true; } }; namespace { -bool GetResizeScales(const InitializedTensorSet& initializers, - const Node& node, std::vector& scales, - const logging::Logger&) { +std::vector GetAxes(const NodeAttrHelper& helper, size_t input_rank) { + auto axes = helper.Get("axes", std::vector{}); + if (axes.empty()) { + axes.resize(input_rank); + std::iota(axes.begin(), axes.end(), 0); + } else { + for (auto& value : axes) { + if (value < 0) { + value = HandleNegativeAxis(value, input_rank); + } + } + } + + return axes; +} + +bool GetValidatedResizeScales(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& input_shape, + const std::vector& axes, + std::vector& scales, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 3) + int64_t input_rank = input_shape.size(); + + if (input_shape[input_rank - 2] == -1 || input_shape[input_rank - 1] == -1) { + LOGS(logger, VERBOSE) << "Resize with 'scales' requires the H and W dimensions to have fixed values"; return false; + } - const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); - if (scales_tensor.dims_size() != 1 || scales_tensor.dims()[0] != 4) + const auto* scales_tensor = 
graph_viewer.GetConstantInitializer(input_defs[2]->Name()); + if (!scales_tensor) { + LOGS(logger, VERBOSE) << "Resize 'scales' input must be a constant initializer"; return false; - Initializer unpacked_tensor(scales_tensor); + } + + Initializer unpacked_tensor(*scales_tensor); auto scales_data = unpacked_tensor.DataAsSpan(); - scales = std::vector{scales_data.begin(), scales_data.end()}; + scales.assign(scales_data.begin(), scales_data.end()); + + for (size_t idx = 0, end = axes.size(); idx < end; ++idx) { + auto axis = axes[idx]; + auto scale = scales[idx]; + if (axis < (input_rank - 2) && scale != 1.0f) { + LOGS(logger, VERBOSE) << "Resize only supports resizing the last two axes. Scale of axis " << axis << " is " + << scale; + return false; + } + } + return true; } -bool GetResizeOutputSizes(const InitializedTensorSet& initializers, - const Node& node, std::vector& sizes, - const logging::Logger&) { +bool GetValidatedResizeSizes(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& input_shape, + const std::vector& axes, + std::vector& sizes, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 4) - return false; + int64_t input_rank = input_shape.size(); - const auto& sizes_tensor = *initializers.at(input_defs[3]->Name()); - if (sizes_tensor.dims_size() != 1 || sizes_tensor.dims()[0] != 4) + const auto* sizes_tensor = graph_viewer.GetConstantInitializer(input_defs[3]->Name()); + if (!sizes_tensor) { + LOGS(logger, VERBOSE) << "Resize 'sizes' input must be a constant initializer"; return false; - Initializer unpacked_tensor(sizes_tensor); + } + + Initializer unpacked_tensor(*sizes_tensor); auto sizes_data = unpacked_tensor.DataAsSpan(); - sizes = std::vector(sizes_data.begin(), sizes_data.end()); + sizes.assign(sizes_data.begin(), sizes_data.end()); + + for (size_t idx = 0, end = axes.size(); idx < end; ++idx) { + auto axis = axes[idx]; + auto cur_size = input_shape[idx]; + auto new_size = sizes[idx]; + if (axis < (input_rank - 2) && cur_size != new_size) { + LOGS(logger, VERBOSE) << "Resize only supports resizing the last two axes. Input rank: " << input_rank + << " Change to size of axis " << axis << " from " << cur_size << " to " << new_size; + return false; + } + } + return true; } } // namespace void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - // We don't really use ROI here, so add it to skipped list if it's an initializer tensor - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI - model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); // ROI - - // We will still add scales to the skipped list even sizes are present - // since there is no use of it, we will not process it later - model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // scales - model_builder.AddInputToSkip(node.InputDefs()[2]->Name()); // scales - - if (node.InputDefs().size() > 3) { - model_builder.AddInitializerToSkip(node.InputDefs()[3]->Name()); // sizes - model_builder.AddInputToSkip(node.InputDefs()[3]->Name()); // sizes + const auto& input_defs = node.InputDefs(); + + // In Resize-11 both roi and scales were required even if you were using sizes. + // https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11 + // From Resize-13 on they're all optional. + // + // We don't support roi so would never take a node with meaningful roi input. 
The roi input can however be provided + // and is ignored unless coordinate_transformation_mode is set to 'tf_crop_and_resize'. + // e.g. our unit tests tend to always provide an empty tensor as roi input instead of as a missing optional input. + // Due to this we always call AddInputToSkip on the roi input. + // + // We require the sizes or scales input to be a constant initializers to take the node (i.e. they won't be an input + // to the CoreML model for the partition, so calling AddInputToSkip isn't relevant). + // Individual values from scales and sizes are added directly to the layer, so we won't use the initializer. + // + // That leaves an edge case for Resize-11 where scales could have been provided as an empty input tensor but + // we're using a constant initializer for sizes. In this case AddInputToSkip needs to be called for the scales input. + + model_builder.AddInitializerToSkip(input_defs[1]->Name()); // roi + model_builder.AddInputToSkip(input_defs[1]->Name()); + + if (input_defs[2]->Exists()) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); // scales + } + + if (input_defs.size() > 3 && input_defs[3]->Exists()) { + model_builder.AddInitializerToSkip(input_defs[3]->Name()); // sizes + + if (node.SinceVersion() < 13) { + model_builder.AddInputToSkip(input_defs[2]->Name()); // skip the unused scales input + } } } -Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); + const auto input_defs = node.InputDefs(); + const auto output_defs = node.OutputDefs(); + const auto& graph_viewer = model_builder.GetGraphViewer(); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Error getting input shape"); + size_t input_rank = input_shape.size(); + + // we know we have either a scales or sizes input so this is safe. + // check for sizes first. this handles Resize-11 where scales was a required input but sizes were used if provided. + bool using_sizes = input_defs.size() >= 4 && input_defs[3]->Exists(); + bool using_scales = !using_sizes; - auto* coreml_upsample = layer->mutable_upsample(); NodeAttrHelper helper(node); - const auto mode = helper.Get("mode", "nearest"); - if (mode == "linear") { - coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_BILINEAR); - } else { // we already checked the mode must be NN or Bilinear in IsOpSupportedImpl - coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_NN); + const auto& mode = helper.Get("mode", "nearest"); + bool is_nearest = mode == "nearest"; + bool is_linear = !is_nearest; + + auto axes = GetAxes(helper, input_rank); + std::vector output_scales; + std::vector output_sizes; + size_t num_scales = 0; + size_t num_sizes = 0; + + if (using_scales) { + ORT_RETURN_IF_NOT(GetValidatedResizeScales(graph_viewer, node, input_shape, axes, output_scales, logger), + "Error getting validated scales"); + num_scales = output_scales.size(); + + // special case linear downsample. + // the CoreML implementation seems to be flaky and gives different outputs on different OS versions. + // use bilinear_resize instead. we check in IsOpSupportedImpl that the downsample input is evenly + // divisible by the output size so there's no rounding involved. 
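For the linear-downsample special case handled in the hunk that follows, scales are converted to target sizes so that the bilinear resize operator can be used; because the supported-check only admits scales whose reciprocal evenly divides the input extent, the multiplication is exact and no rounding is involved. A standalone sketch, assuming an NCHW input:

```cpp
// Standalone sketch of the scales -> sizes conversion for exact linear downsampling.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> input_shape = {1, 3, 8, 8};        // NCHW
  std::vector<float> scales = {1.f, 1.f, 0.5f, 0.25f};    // only H and W change

  const size_t rank = input_shape.size();
  std::vector<int64_t> sizes = input_shape;
  // 8 * 0.5 = 4 and 8 * 0.25 = 2 exactly, so the cast loses nothing.
  sizes[rank - 2] = static_cast<int64_t>(input_shape[rank - 2] * scales[rank - 2]);
  sizes[rank - 1] = static_cast<int64_t>(input_shape[rank - 1] * scales[rank - 1]);

  std::cout << "target H=" << sizes[rank - 2] << " W=" << sizes[rank - 1] << "\n";  // H=4 W=2
  return 0;
}
```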
+ if (is_linear && (output_scales[num_scales - 1] < 1.f || output_scales[num_scales - 2] < 1.f)) { + using_scales = false; + using_sizes = true; + num_sizes = num_scales; + output_sizes = input_shape; + // only the last two dims have their size changed + output_sizes[input_rank - 2] = static_cast(input_shape[input_rank - 2] * output_scales[num_scales - 2]); + output_sizes[input_rank - 1] = static_cast(input_shape[input_rank - 1] * output_scales[num_scales - 1]); + } + } else { + ORT_RETURN_IF_NOT(GetValidatedResizeSizes(graph_viewer, node, input_shape, axes, output_sizes, logger), + "Error getting validated sizes"); + num_sizes = output_sizes.size(); } - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - - if (input_defs.size() >= 3 && input_defs[2]->Exists()) { // use scales - std::vector scales; - ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); - coreml_upsample->add_scalingfactor(static_cast(scales[2])); - coreml_upsample->add_scalingfactor(static_cast(scales[3])); - } else { // we already checked number of inputs in IsOpSupportedImpl - std::vector input_shape; - ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Error getting input shape"); - std::vector output_sizes; - ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), - "Error getting resize output_sizes"); - coreml_upsample->add_scalingfactor(static_cast(output_sizes[2] / input_shape[2])); - coreml_upsample->add_scalingfactor(static_cast(output_sizes[3] / input_shape[3])); - } +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; // NOLINT + + std::string_view coreml_op_type; + if (using_scales) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.upsample_bilinear + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.upsample_nearest_neighbor + coreml_op_type = is_linear ? "upsample_bilinear" : "upsample_nearest_neighbor"; + } else { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resize_bilinear + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resize_nearest_neighbor + coreml_op_type = is_linear ? 
"resize_bilinear" : "resize_nearest_neighbor"; + } + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + + std::string coord_trans_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); + + if (using_scales) { + float scale_height = output_scales[num_scales - 2]; + float scale_width = output_scales[num_scales - 1]; + AddOperationInput(*op, "scale_factor_height", + model_builder.AddScalarConstant(coreml_op_type, "scale_factor_height", scale_height)); + AddOperationInput(*op, "scale_factor_width", + model_builder.AddScalarConstant(coreml_op_type, "scale_factor_width", scale_width)); + + if (is_linear) { + // we only allow these coord modes in the 'is supported' check, + // - half_pixel or pytorch_half_pixel with output size > 1 -> align_corners = false + // - align_corners -> align_corners = true + bool align_corners = coord_trans_mode == "align_corners"; + AddOperationInput(*op, "align_corners", + model_builder.AddScalarConstant(coreml_op_type, "align_corners", align_corners)); + } + } else { + assert(using_sizes); + int64_t target_height = output_sizes[num_sizes - 2]; + int64_t target_width = output_sizes[num_sizes - 1]; + + AddOperationInput(*op, "target_size_height", + model_builder.AddScalarConstant(coreml_op_type, "target_size_height", target_height)); + AddOperationInput(*op, "target_size_width", + model_builder.AddScalarConstant(coreml_op_type, "target_size_width", target_width)); + + if (is_linear) { + // we only allow these coord modes in the 'is supported' check, + // - half_pixel or pytorch_half_pixel with output size > 1 -> UNALIGN_CORNERS + // - align_corners -> STRICT_ALIGN_CORNERS + // - asymmetric -> DEFAULT + std::string sampling_mode_value; + if (coord_trans_mode == "asymmetric") { + sampling_mode_value = "DEFAULT"; + } else if (coord_trans_mode == "align_corners") { + sampling_mode_value = "STRICT_ALIGN_CORNERS"; + } else { + sampling_mode_value = "UNALIGN_CORNERS"; + } + + AddOperationInput(*op, "sampling_mode", + model_builder.AddScalarConstant(coreml_op_type, "sampling_mode", sampling_mode_value)); + } + } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + AddOperationOutput(*op, *output_defs[0]); + model_builder.AddOperation(std::move(op)); + } else // NOLINT +#endif + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto* coreml_upsample = layer->mutable_upsample(); + + // we already checked the mode must be NN or Bilinear in IsOpSupportedImpl + if (is_linear) { + coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_BILINEAR); + } else { + coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_NN); + } + + if (using_scales) { + coreml_upsample->add_scalingfactor(static_cast(output_scales[num_scales - 2])); + coreml_upsample->add_scalingfactor(static_cast(output_scales[num_scales - 1])); + } else { + auto scale_height = output_sizes[num_sizes - 2] / input_shape[input_rank - 2]; + auto scale_width = output_sizes[num_sizes - 1] / input_shape[input_rank - 1]; + coreml_upsample->add_scalingfactor(static_cast(scale_height)); + coreml_upsample->add_scalingfactor(static_cast(scale_width)); + } + + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = output_defs[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } bool 
ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "Resize: input shape was not known"; return false; + } - const auto input_size = input_shape.size(); - if (input_size != 4) { - LOGS(logger, VERBOSE) << "Resize only support 4d shape, input is " - << input_size << "d shape"; + // as we allow empty shapes in the checks done by BaseOpBuilder::HasSupportedInputs we explicitly check for an empty + // an empty input here to be consistent. + // this should never happen in a real model though as a dim with value 0 (i.e. no input data) would typically be a + // dynamic dimension where a previous step had no output (e.g. Loop of zero interations, NonZero with no matches, + // NonMaxSupression with no boxes). + if (DoesShapeSpecifyZeroElements(input_shape)) { + LOGS(logger, VERBOSE) << "Resize input shape has with dimension values of 0 which is not supported."; return false; } - { // check attributes - NodeAttrHelper helper(node); - const auto mode = helper.Get("mode", "nearest"); - bool is_linear_resize = mode == "linear"; - bool is_nearest_resize = mode == "nearest"; - if (!is_linear_resize && !is_nearest_resize) { - LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode; + const auto input_rank = input_shape.size(); + if (input_params.create_mlprogram) { + if (input_rank < 3 || input_rank > 5) { + LOGS(logger, VERBOSE) << "Resize only supports 3D to 5D input. Got: " << input_rank << "D"; return false; } - - const auto exclude_outside = helper.Get("exclude_outside", 0); - if (exclude_outside != 0) { - LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; + } else { + if (input_rank != 4) { + LOGS(logger, VERBOSE) << "Resize only support 4d shape. 
Got: " << input_rank << "D"; return false; } + } - const auto coord_trans_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); - bool using_asymmetric = coord_trans_mode == "asymmetric"; - if (is_linear_resize) { - // TODO, add support of align_corners and half_pixel - if (!using_asymmetric) { - LOGS(logger, VERBOSE) << "Resize bilinear, unsupported coord_trans_mode, " << coord_trans_mode; - return false; - } - } else { - // nearest neighbor resizing - // For resize using nearest neighbor, we only support coord_trans_mode == "asymmetric" && nearest_mode == "floor" - if (!using_asymmetric) { - LOGS(logger, VERBOSE) << "Resize nearest neighbor, unsupported coord_trans_mode, " << coord_trans_mode; - return false; - } + // check attributes + NodeAttrHelper helper(node); - const auto nearest_mode = helper.Get("nearest_mode", "round_prefer_floor"); - if (nearest_mode != "floor") { - LOGS(logger, VERBOSE) << "Resize nearest neighbor, unsupported nearest_mode, " << nearest_mode; - return false; - } - } + if (helper.Get("antialias", 0) != 0) { + LOGS(logger, VERBOSE) << "Resize does not support antialias"; + return false; } - { // scales and sizes (if present) must be initializers - if (input_defs.size() < 3) { - LOGS(logger, VERBOSE) << "Input scales or sizes of Resize must be known"; - return false; - } + const auto& mode = helper.Get("mode", "nearest"); + bool is_linear = mode == "linear"; + bool is_nearest = mode == "nearest"; + if (!is_linear && !is_nearest) { + LOGS(logger, VERBOSE) << "Resize unsupported input mode: " << mode; + return false; + } - bool using_scales = input_defs.size() >= 3 && input_defs[2]->Exists(); - // scales - if (using_scales && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { - LOGS(logger, VERBOSE) << "scales input of Resize must be a constant initializer"; + if (is_nearest) { + const auto nearest_mode = helper.Get("nearest_mode", "round_prefer_floor"); + if (nearest_mode != "floor") { + LOGS(logger, VERBOSE) << "Resize only supports 'floor' nearest_mode. Got: " << nearest_mode; return false; } + } - // sizes - if (!using_scales && - (input_defs.size() < 4 || - !input_defs[3]->Exists() || - !input_params.graph_viewer.GetConstantInitializer(input_defs[3]->Name()))) { - LOGS(logger, VERBOSE) << "sizes input of Resize must be a constant initializer"; - return false; - } + if (helper.Get("exclude_outside", 0) != 0) { + LOGS(logger, VERBOSE) << "Resize does not support 'exclude_outside'"; + return false; + } - // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (using_scales) { - std::vector scales; - if (!GetResizeScales(initializers, node, scales, logger)) - return false; + const auto keep_aspect_ratio_policy = helper.Get("keep_aspect_ratio_policy", "stretch"); + if (keep_aspect_ratio_policy != "stretch") { + LOGS(logger, VERBOSE) << "Resize only supports keep_aspect_ratio_policy of 'stretch'. Got " + << keep_aspect_ratio_policy; + return false; + } - float scale_n = scales[0]; - float scale_c = scales[1]; - if (scale_n != 1.0f || scale_c != 1.0f) { - LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" - << "Resize of N/C channels are not supported" - << ", scale_n, " << scale_n << ", scale_c, " << scale_c; - return false; - } + // check for sizes first. this handles Resize-11 where scales was a required input but sizes were used if provided. 
+ bool using_sizes = input_defs.size() >= 4 && input_defs[3]->Exists(); + bool using_scales = !using_sizes && input_defs.size() >= 3 && input_defs[2]->Exists(); - // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1 - // TODO support ResizeBilinear - float scale_h = scales[2]; - float scale_w = scales[3]; + if (!using_scales && !using_sizes) { + LOGS(logger, VERBOSE) << "Resize requires 'scales' or 'sizes' input"; + return false; + } + + // 'axes' is from opset 18 on and allows scales or sizes to have entries for the subset of axes. + // we fill with default values if necessary so that the processing is consistent across all supported opsets. + auto axes = GetAxes(helper, input_rank); + std::vector output_scales; + std::vector output_sizes; + + // make sure scales/sizes are constant initializers, and are only modifying the last two dimensions of the input. + if (using_scales) { + if (!GetValidatedResizeScales(input_params.graph_viewer, node, input_shape, axes, output_scales, logger)) { + return false; + } - // Onnx spec requires scale to be a positive float, so we are not checking that here + size_t num_scales = output_scales.size(); + float scale_h = output_scales[num_scales - 2]; + float scale_w = output_scales[num_scales - 1]; + + // NeuralNetwork supports upsample only with round numbers. + // + // ML Program results seem to match if round numbers are involved. When downsampling the scaling value should be + // 1 / . e.g. if input size is 8, scaling factor could be 1/8, 1/4 or 1/2. + if (scale_h >= 1.f && scale_w >= 1.f) { + // upsample (or no-op with both == 1.f that we won't bother special-casing) if (roundf(scale_h) != scale_h) { LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " is not a whole number"; return false; @@ -228,33 +421,57 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; return false; } - } else { - // we are using sizes - std::vector output_sizes; - if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) - return false; - - if (!IsStaticShape(input_shape)) { - LOGS(logger, VERBOSE) << "Input shape with dynamic dimensions is not supported."; + } else if (scale_h <= 1.f && scale_w <= 1.f) { + // downsample + if (input_params.create_mlprogram) { + auto h_in = input_shape[input_rank - 2]; + auto w_in = input_shape[input_rank - 1]; + + if (!utils::IsScalingByAFactorOfN(h_in, scale_h)) { + LOGS(logger, VERBOSE) << "Resize: downsampling scale " << scale_h + << " is not a factor of input height: " << h_in; + return false; + } + + if (!utils::IsScalingByAFactorOfN(w_in, scale_w)) { + LOGS(logger, VERBOSE) << "Resize: downsampling scale " << scale_w + << " is not a factor of input width: " << w_in; + return false; + } + + } else { + LOGS(logger, VERBOSE) << "Resize: downsampling is not supported."; return false; } + } else { + LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " and scale_w: " << scale_w + << " must both be >= 1 or <= 1"; + return false; + } + } else { + assert(using_sizes); + + if (!GetValidatedResizeSizes(input_params.graph_viewer, node, input_shape, axes, output_sizes, logger)) { + return false; + } - auto output_size_n = output_sizes[0]; - auto output_size_c = output_sizes[1]; - if (output_size_n != input_shape[0] || output_size_c != input_shape[1]) { - LOGS(logger, VERBOSE) << "Output sizes of N/C channel should match the input sizes, " - << "Resize of N/C channels are not 
supported" - << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n - << ". input_size_c, " << input_shape[1] << ", output_size_c, " << output_size_c; + if (input_params.create_mlprogram) { + // no additional requirements + } else { + if (!IsStaticShape(input_shape)) { + // need to convert from sizes to scales when creating the NN layer, so the input H and W are required + LOGS(logger, VERBOSE) << "Resize input shape with dynamic dimensions is not supported."; return false; } - // For now we only support upscale, so the output_size_h and output_size_w should be an integer >= 1 + // For now we only support upsample, so the output_size_h and output_size_w should be an integer >= 1 // TODO support ResizeBilinear - auto output_size_h = output_sizes[2]; - auto output_size_w = output_sizes[3]; - auto input_size_h = input_shape[2]; - auto input_size_w = input_shape[3]; + auto input_size_h = input_shape[input_rank - 2]; + auto input_size_w = input_shape[input_rank - 1]; + + auto num_sizes = output_sizes.size(); // could be smaller than input_rank if axes was used + auto output_size_h = output_sizes[num_sizes - 2]; + auto output_size_w = output_sizes[num_sizes - 1]; // Onnx spec requires output sizes to be a positive integer, so we are not checking that here if (output_size_h % input_size_h != 0) { @@ -271,6 +488,92 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa } } + std::string coord_trans_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); + bool using_asymmetric = coord_trans_mode == "asymmetric"; + + if (input_params.create_mlprogram) { + if (is_nearest) { + // Potential CoreML operators we could map to: + // + // image_resizing.upsample_nearest_neighbor + // - mode: nearest + // - coordinate_transformation_mode: asymmetric + // - 'scales' input + // + // image_resizing.resize_nearest_neighbor + // - mode: nearest + // - coordinate_transformation_mode: asymmetric + // - 'sizes' input + if (!using_asymmetric) { + LOGS(logger, VERBOSE) << "Resize with 'mode' of 'nearest' requires 'coordinate_transformation_mode' of " + "'asymmetric' . 
Got: " + << coord_trans_mode; + return false; + } + } else { + assert(is_linear); + // Potential CoreML operators we could map to: + // + // image_resizing.upsample_bilinear + // - mode: linear + // - 'scales' input + // - coordinate_transformation_mode + // - half_pixel -> align_corners = false + // - align_corners -> align_corners = true + // + // image_resizing.resize_bilinear + // - mode: linear + // - 'sizes' input + // - coordinate_transformation_mode -> sampling_mode + // - half_pixel -> UNALIGN_CORNERS + // - align_corners -> STRICT_ALIGN_CORNERS + // - asymmetric -> DEFAULT + // + + // if output size != 1, coordinate_transformation_mode of pytorch_half_pixel is the same as half_pixel + if (coord_trans_mode == "pytorch_half_pixel") { + int64_t h_out{0}, w_out{0}; + if (using_scales) { + size_t num_scales = output_scales.size(); + h_out = std::llround(input_shape[input_rank - 2] * output_scales[num_scales - 2]); + w_out = std::llround(input_shape[input_rank - 1] * output_scales[num_scales - 1]); + } else { + size_t num_sizes = output_sizes.size(); + h_out = output_sizes[num_sizes - 2]; + w_out = output_sizes[num_sizes - 1]; + } + + if (h_out > 1 && w_out > 1) { + coord_trans_mode = "half_pixel"; + } + } + + if (coord_trans_mode == "half_pixel" || + coord_trans_mode == "align_corners" || + (using_sizes && coord_trans_mode == "asymmetric")) { + // supported + + // FWIW we could calculate (if shape inferencing didn't already) the output sizes and convert a node with + // `scales` and co-ord mode of `asymmetric` to having `sizes` input so it's supported. + } else { + LOGS(logger, VERBOSE) << "Resize with 'mode' of 'linear' requires 'coordinate_transformation_mode' of " + "'half_pixel', or 'align_corners', or 'asymmetric' with 'sizes' input. Got: " + << coord_trans_mode; + + return false; + } + } + } else { + // NeuralNetwork checks + if (!using_asymmetric) { + // align_corners and half_pixel could be supported in ResizeBilinear but as NeuralNetwork is deprecated + // there's no known value to adding that. + LOGS(logger, VERBOSE) << "Resize only supports 'asymmetric' coordinate_transformation_mode. Got: " + << coord_trans_mode; + return false; + } + } + return true; } diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index 8f85ab2c09e7c..385588dbfdcb8 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -141,8 +141,17 @@ class ModelBuilder { // so we don't do a copy of the original initializer into the model. void AddInitializerToSkip(const std::string& tensor_name); - // There are some input which will not be used, add it to a list which will not - // be added to CoreML model, since CoreML does not like input unused + /// + /// Skip a non-initializer value, that is not used in the CoreML model, but was an input to a supported node. + /// + /// This is for a rare edge case where a value is an input to a node but is empty/unused, as the + /// CoreML model requires all model inputs to be consumed. + /// + /// + /// The only known use case for this currently is Resize, and that is largely due to how the unit tests are + /// setup rather than something you'd expect to see in a real model. + /// See ResizeOpBuilder::AddInitializersToSkip for more details. 
+ /// void AddInputToSkip(const std::string& input_name); const std::string& GetUniqueName(const std::string& base_name); diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 0ba715cc7c6d9..a92fef81ac395 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -27,6 +27,7 @@ CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, coreml_flags_(coreml_flags), coreml_version_(coreml::util::CoreMLVersion()) { + LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_; if (coreml_version_ < MINIMUM_COREML_VERSION) { LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index d75b9cc72ff4b..ef27f6c942f44 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -9,6 +9,7 @@ #include "core/graph/graph_viewer.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" +#include "core/providers/utils.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h" @@ -251,14 +252,34 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const N const Initializer unpacked_tensor(*scales); auto scales_data = unpacked_tensor.DataAsSpan(); input_is_nchw = scales_data[1] == 1.0F; - float const scale_n = scales_data[0]; - float const scale_c = input_is_nchw ? scales_data[1] : scales_data[3]; + const float scale_n = scales_data[0]; + const float scale_c = input_is_nchw ? scales_data[1] : scales_data[3]; + const float scale_h = input_is_nchw ? scales_data[2] : scales_data[1]; + const float scale_w = input_is_nchw ? scales_data[3] : scales_data[2]; + if (scale_n != 1.0f || scale_c != 1.0f) { LOGS_DEFAULT(VERBOSE) << "Scales of N/C channel should be 1" << "Resize of N/C channels are not supported" << ", scale_n, " << scale_n << ", scale_c, " << scale_c; return false; } + + // if downsampling the input size must be evenly divisible by the output size to match the onnx output + if (scale_h < 1.0f || scale_w < 1.0f) { + // we also require input_shape to be known to check + auto h_in = input_is_nchw ? input_shape[2] : input_shape[1]; + auto w_in = input_is_nchw ? 
input_shape[3] : input_shape[2]; + if (h_in == 0 || w_in == 0) { + LOGS_DEFAULT(VERBOSE) << "Input H and W must be known to downsample with scales"; + return false; + } + + if (!utils::IsScalingByAFactorOfN(h_in, scale_h) || + !utils::IsScalingByAFactorOfN(w_in, scale_w)) { + LOGS_DEFAULT(VERBOSE) << "Input size must be evenly divisible by output size when downsampling"; + return false; + } + } } else { const auto* sizes = graph_viewer.GetConstantInitializer(inputs[3].node_arg.Name()); if (!sizes) { diff --git a/onnxruntime/core/providers/utils.cc b/onnxruntime/core/providers/utils.cc index b2f9d265ca053..747b09e42aa21 100644 --- a/onnxruntime/core/providers/utils.cc +++ b/onnxruntime/core/providers/utils.cc @@ -23,5 +23,21 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& return Status::OK(); } #endif + +bool IsScalingByAFactorOfN(int64_t n, float scale) { + bool is_factor = false; + if (scale > 0.f && scale < 1.f) { + const double factor = 1.0 / scale; + const double factor_rounded = std::round(factor); + constexpr double epsilon = 1.0e-4; // arbitrarily small enough + if (std::abs(factor - factor_rounded) < epsilon) { + // result is integer. check if a factor of n + const int64_t factor_i = static_cast(factor_rounded); + is_factor = n % factor_i == 0; + } + } + + return is_factor; +} } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/utils.h b/onnxruntime/core/providers/utils.h index 8cafdb8c05cc3..9ea8496a02f85 100644 --- a/onnxruntime/core/providers/utils.h +++ b/onnxruntime/core/providers/utils.h @@ -15,5 +15,10 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& OpKernelContext* context, int output_index); #endif +/// +/// Check if the reciprocal of 'scale' is a factor of 'n'. +/// e.g. a scale of 0.5 is 1/2, the reciprocal is 2, and 2 is a factor of any even number. 
+/// +bool IsScalingByAFactorOfN(int64_t n, float scale); } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 09666c8039402..c752b5f849808 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -11,6 +11,7 @@ #include "core/framework/op_kernel.h" #include "core/optimizer/initializer.h" #include "core/providers/xnnpack/xnnpack_init.h" +#include "core/providers/utils.h" namespace onnxruntime { namespace xnnpack { @@ -68,9 +69,27 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, InlinedVector scale(4, 1.0F); if (scale_tensor) { const Initializer scale_val(*scale_tensor, node_unit.ModelPath()); - if (scale_val.DataAsSpan()[1] != 1.0F) { + const auto scales = scale_val.DataAsSpan(); + if (scales[1] != 1.0F) { break; } + + // downsampling output seems to require the output size to be a factor of the input to match ONNX + if (scales[2] < 1.0f || scales[3] < 1.0f) { + // we also require input_shape to be known to check + int64_t h_in = x_shape->dim(2).dim_value(); + int64_t w_in = x_shape->dim(3).dim_value(); + if (h_in < 0 || w_in < 0) { + break; + } + + float scale_h = scales[2]; + float scale_w = scales[3]; + if (!utils::IsScalingByAFactorOfN(h_in, scale_h) || + !utils::IsScalingByAFactorOfN(w_in, scale_w)) { + break; + } + } } if (size_tensor) { diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 496f2213e9d32..111520ef03e26 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -227,28 +227,33 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e } TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { - OpTester test("Resize", 13); - std::vector roi{}; - std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; + auto run_test = [](bool scales_in_initializer) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; - test.AddAttribute("mode", "linear"); + test.AddAttribute("mode", "linear"); - constexpr int64_t N = 1, C = 1, H = 2, W = 4; - std::vector X = { - 1.0f, 2.0f, 3.0f, 4.0f, - 5.0f, 6.0f, 7.0f, 8.0f}; + constexpr int64_t N = 1, C = 1, H = 2, W = 4; + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; - test.AddInput("X", {N, C, H, W}, X); - test.AddInput("roi", {0}, roi); - test.AddInput("scales", {4}, scales); + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {4}, scales, scales_in_initializer); - std::vector Y = {2.66666651f, 4.3333331f}; + std::vector Y = {2.66666651f, 4.3333331f}; - test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - // QNN: result diff - // TRT: Segmentation fault in A100 - std::unordered_set excluded_providers({kQnnExecutionProvider}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers)); + test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); + // QNN: result diff + // TRT: Segmentation fault in A100 + std::unordered_set excluded_providers({kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers)); + }; + + run_test(false); + run_test(true); } TEST(ResizeOpTest, 
NhwcResizeOpLinearDownSampleTest_4DBilinear) { @@ -327,13 +332,14 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { // Since NNAPI(TFLite) only using the scale calculate using the input/output size // For the above test (ResizeOpLinearDownSampleTest_4DBilinear) // The output size is [1,1,2,4].*[1,1,0.6,0.6]=[1,1,1,2] -// NNAPI will recaluclate the scales as the output size divided by input size +// NNAPI will recalculate the scales as the output size divided by input size // scales = [1,1,1,2]./[1,1,2,4] = [1,1,0.5,0.5] // See:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h // So the result of the above example will be different than CPU EP -// Add the following 2 tests to test with scales valid to NNAPI +// Add the following 2 tests to test with scales valid to NNAPI. +// CoreML also doesn't handle a scale that doesn't divide the input size evenly. TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -360,8 +366,38 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) { run_test(true); } +// Downsize with factor being an odd number (1/3) +TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_OddNumber) { + // To test NNAPI EP, we need the scales/sizes to be in initializers + auto run_test = [](bool scales_in_initializer) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 1.0f, (1.f / 3), (1.f / 3)}; + + test.AddAttribute("mode", "linear"); + + constexpr int64_t N = 1, C = 1, H = 3, W = 6; + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f}; + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {4}, scales, scales_in_initializer); + + std::vector Y = {8.f, 11.f}; + + test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); + }; + + run_test(false); + run_test(true); +} + TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_WithSizes) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_and_sizes_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -389,8 +425,32 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_WithSizes) { run_test(true); } +// test handling for opset 11. 
scales input is provided but should be ignored in favor of sizes +TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_WithSizesOpset11) { + OpTester test("Resize", 11); + std::vector roi{}; + std::vector scales{}; + constexpr int64_t N = 1, C = 1, H = 2, W = 4; + std::vector sizes{N, C, 1, 2}; + test.AddAttribute("mode", "linear"); + + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {0}, scales); + test.AddInput("sizes", {4}, sizes, true); // add as initializer so CoreML EP can take + + std::vector Y = {3.5f, 5.5f}; + + test.AddOutput("Y", sizes, Y); + test.Run(); +} + TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -416,15 +476,51 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) { run_test(false); -#ifdef USE_NNAPI - // NNAPI will need the scales as an initializer +#if defined(USE_NNAPI) || defined(USE_COREML) + // NNAPI and CoreML need the scales as an initializer + // Also tensor RT EP will fail if scales is an initializer but will pass if it is not + run_test(true); +#endif +} + +TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners_sizes) { + // To test NNAPI EP, we need the scales/sizes to be in initializers + auto run_test = [](bool scales_in_initializer) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{}; + std::vector sizes{1, 1, 1, 2}; + + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", "align_corners"); + + constexpr int64_t N = 1, C = 1, H = 2, W = 4; + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("", {0}, scales); + test.AddInput("sizes", {4}, sizes, scales_in_initializer); + + std::vector Y = {1.0f, 4.0f}; + + test.AddOutput("Y", {N, C, 1, 2}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); + }; + + run_test(false); + +#if defined(USE_NNAPI) || defined(USE_COREML) + // NNAPI and CoreML will need the scales as an initializer // Also tensor RT EP will fail if scales is an initializer but will pass if it is not run_test(true); #endif } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uint8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -456,7 +552,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -622,7 +718,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric_scales) { } TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To 
test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -668,7 +764,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { } TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_int8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 1bbb933f66ba4..3b3790ba06599 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -17,6 +17,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Pow|Only supports cases when both inputs are fp32.| |ai.onnx:Relu|| |ai.onnx:Reshape|| +|ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| |ai:onnx:Tanh||
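
For reference, the divisibility check this patch adds in onnxruntime/core/providers/utils.cc (`IsScalingByAFactorOfN`, used by the NNAPI and XNNPACK Resize support checks above) can be exercised in isolation. The following is a minimal standalone sketch: the function body mirrors the patch, while the `main()` driver and the sample values are illustrative only and not part of the patch.

```cpp
// Standalone sketch of the downsampling scale check added in
// onnxruntime/core/providers/utils.cc. The main() driver below is illustrative.
#include <cmath>
#include <cstdint>
#include <iostream>

// Returns true if 'scale' is a downsampling scale whose reciprocal is
// (approximately) an integer that evenly divides 'n'.
// e.g. a scale of 0.5 has reciprocal 2, which divides any even n.
bool IsScalingByAFactorOfN(int64_t n, float scale) {
  bool is_factor = false;
  if (scale > 0.f && scale < 1.f) {
    const double factor = 1.0 / scale;
    const double factor_rounded = std::round(factor);
    constexpr double epsilon = 1.0e-4;  // tolerance for float scales like 1.f/3
    if (std::abs(factor - factor_rounded) < epsilon) {
      // reciprocal is effectively an integer; check it divides n
      const auto factor_i = static_cast<int64_t>(factor_rounded);
      is_factor = n % factor_i == 0;
    }
  }
  return is_factor;
}

int main() {
  // W=6 downsampled by 1/3 -> reciprocal 3 divides 6: supported
  // (matches the new ResizeOpLinearDownSampleTest_4DBilinear1_OddNumber case).
  std::cout << IsScalingByAFactorOfN(6, 1.f / 3) << "\n";  // 1
  // W=4 downsampled by 0.6 -> reciprocal ~1.667 is not an integer: rejected.
  std::cout << IsScalingByAFactorOfN(4, 0.6f) << "\n";     // 0
  // Upsampling scales (>= 1) are not the case this helper guards against.
  std::cout << IsScalingByAFactorOfN(4, 2.0f) << "\n";     // 0
  return 0;
}
```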