diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 2966a4624a966..a8c876d30873e 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -15,6 +15,8 @@ set(contrib_ops_excluded_files "bert/attention_softmax.h" "bert/attention_softmax.cu" "bert/attention_prepare_qkv.cu" + "bert/attention_kernel_options.h" + "bert/attention_kernel_options.cc" "bert/decoder_attention_impl.h" "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0159c35d1941b..38ed0b1640192 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common) target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) - target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common) if (MSVC) # Cutlass code has an issue with the following: # warning C4100: 'magic': unreferenced formal parameter diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e609745b5e03f..0bb5c7432f0a7 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -10,6 +10,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" #include "core/optimizer/graph_transformer.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/optimizer/rule_based_graph_transformer.h" @@ -49,7 +50,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, - const InlinedHashSet& rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -78,7 +80,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 6d53760ab60b5..01a14de699dc4 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 { int prefer_nhwc = 0; // make 
the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not int use_tf32 = 1; // use TF32 + int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option }; diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index d008058821be3..816eaaf9bc71a 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -16,6 +16,7 @@ struct OrtTensorRTProviderOptionsV2 { int device_id{0}; // cuda device id. int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream. void* user_compute_stream{nullptr}; // user specified CUDA compute stream. + // can be updated using: UpdateTensorRTProviderOptionsWithValue int trt_max_partition_iterations{1000}; // maximum iterations for TensorRT parser to get capability int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs size_t trt_max_workspace_size{1 << 30}; // maximum workspace size for TensorRT. @@ -78,6 +79,12 @@ struct OrtTensorRTProviderOptionsV2 { const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for // the ONNX model containing the weights (applicable only when // the "trt_weight_stripped_engine_enable" option is enabled) + const void* trt_onnx_bytestream{nullptr}; // The byte stream of th original ONNX model containing the weights + // (applicable only when the "trt_weight_stripped_engine_enable" + // option is enabled) + // can be updated using: UpdateTensorRTProviderOptionsWithValue + size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream" + // can be updated using: UpdateTensorRTProviderOptionsWithValue const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index c32e2a77e8453..17ae649e6f174 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // - "0": Gemm FastMath mode is not enabled. [DEFAULT] // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; + +// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. +// Refer to MatMulNBits op schema for more details. +// If not provided, default is 4. 
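As a usage sketch (not part of the patch), the new accuracy-level option can be set through the public C++ API before session creation; the key string matches the constant declared on the next line, and the value "2" is only an illustration of a lower accuracy level:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env ort_env;  // environment needed later for session creation
  Ort::SessionOptions session_options;
  // Valid values are "0" through "4"; the optimizer falls back to 4 when the entry is absent.
  session_options.AddConfigEntry("session.qdq_matmulnbits_accuracy_level", "2");
  // ... create an Ort::Session with these options as usual ...
  return 0;
}
```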
+static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index b802a4e8271a7..3cfc0457c6234 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -194,9 +194,9 @@ } }, "node_modules/@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "node_modules/@szmarczak/http-timer": { @@ -236,9 +236,9 @@ "dev": true }, "node_modules/@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "dependencies": { "@types/node": "*" @@ -1086,9 +1086,9 @@ } }, "node_modules/engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.5.5", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.5.5.tgz", + "integrity": "sha512-C5Pn8Wk+1vKBoHghJODM63yk8MvrO9EWZUfkAt5HAqIgPE4/8FF0PEGHXtEd40l223+cE5ABWuPzm38PHFXfMA==", "dev": true, "dependencies": { "@types/cookie": "^0.4.1", @@ -1099,17 +1099,17 @@ "cookie": "~0.4.1", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - "ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, "node_modules/engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true, "engines": { "node": ">=10.0.0" @@ -3020,35 +3020,37 @@ } }, "node_modules/socket.io": { - "version": "4.6.1", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.1.tgz", - "integrity": "sha512-KMcaAi4l/8+xEjkRICl6ak8ySoxsYG+gG6/XfRCPJPQ/haCRIJBTL4wIl8YCsmtaBovcAXGLOShyVWQ/FG8GZA==", + "version": "4.7.5", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.7.5.tgz", + "integrity": "sha512-DmeAkF6cwM9jSfmp6Dr/5/mfMwb5Z5qRrSXLpo3Fq5SqyU8CMF15jIN4ZhfSwu35ksM1qmHZDQ/DK5XTccSTvA==", "dev": true, "dependencies": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.1", + "engine.io": "~6.5.2", "socket.io-adapter": "~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" }, "engines": { - "node": ">=10.0.0" + "node": ">=10.2.0" } }, 
"node_modules/socket.io-adapter": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", "dev": true, "dependencies": { - "ws": "~8.11.0" + "debug": "~4.3.4", + "ws": "~8.17.1" } }, "node_modules/socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "dependencies": { "@socket.io/component-emitter": "~3.1.0", @@ -3449,16 +3451,16 @@ "dev": true }, "node_modules/ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "engines": { "node": ">=10.0.0" }, "peerDependencies": { "bufferutil": "^4.0.1", - "utf-8-validate": "^5.0.2" + "utf-8-validate": ">=5.0.2" }, "peerDependenciesMeta": { "bufferutil": { @@ -3648,9 +3650,9 @@ "dev": true }, "@socket.io/component-emitter": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz", - "integrity": "sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", "dev": true }, "@szmarczak/http-timer": { @@ -3687,9 +3689,9 @@ "dev": true }, "@types/cors": { - "version": "2.8.13", - "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.13.tgz", - "integrity": "sha512-RG8AStHlUiV5ysZQKq97copd2UmVYw3/pRMLefISZ3S1hK104Cwm7iLQ3fTKx+lsUH2CE8FlLaYeEA2LSeqYUA==", + "version": "2.8.17", + "resolved": "https://registry.npmjs.org/@types/cors/-/cors-2.8.17.tgz", + "integrity": "sha512-8CGDvrBj1zgo2qE+oS3pOCyYNqCPryMWY2bGfwA0dcfopWGgxs+78df0Rs3rc9THP4JkOhLsAa+15VdpAqkcUA==", "dev": true, "requires": { "@types/node": "*" @@ -4379,9 +4381,9 @@ } }, "engine.io": { - "version": "6.4.2", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", - "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", + "version": "6.5.5", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.5.5.tgz", + "integrity": "sha512-C5Pn8Wk+1vKBoHghJODM63yk8MvrO9EWZUfkAt5HAqIgPE4/8FF0PEGHXtEd40l223+cE5ABWuPzm38PHFXfMA==", "dev": true, "requires": { "@types/cookie": "^0.4.1", @@ -4392,14 +4394,14 @@ "cookie": "~0.4.1", "cors": "~2.8.5", "debug": "~4.3.1", - "engine.io-parser": "~5.0.3", - 
"ws": "~8.11.0" + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1" } }, "engine.io-parser": { - "version": "5.0.6", - "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.0.6.tgz", - "integrity": "sha512-tjuoZDMAdEhVnSFleYPCtdL2GXwVTGtNjoeJd9IhIG3C1xs9uwxqRNEu5WpnDZCaozwVlK/nuQhpodhXSIMaxw==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", "dev": true }, "ent": { @@ -5862,32 +5864,34 @@ "dev": true }, "socket.io": { - "version": "4.6.1", - "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.6.1.tgz", - "integrity": "sha512-KMcaAi4l/8+xEjkRICl6ak8ySoxsYG+gG6/XfRCPJPQ/haCRIJBTL4wIl8YCsmtaBovcAXGLOShyVWQ/FG8GZA==", + "version": "4.7.5", + "resolved": "https://registry.npmjs.org/socket.io/-/socket.io-4.7.5.tgz", + "integrity": "sha512-DmeAkF6cwM9jSfmp6Dr/5/mfMwb5Z5qRrSXLpo3Fq5SqyU8CMF15jIN4ZhfSwu35ksM1qmHZDQ/DK5XTccSTvA==", "dev": true, "requires": { "accepts": "~1.3.4", "base64id": "~2.0.0", + "cors": "~2.8.5", "debug": "~4.3.2", - "engine.io": "~6.4.1", + "engine.io": "~6.5.2", "socket.io-adapter": "~2.5.2", - "socket.io-parser": "~4.2.1" + "socket.io-parser": "~4.2.4" } }, "socket.io-adapter": { - "version": "2.5.2", - "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.2.tgz", - "integrity": "sha512-87C3LO/NOMc+eMcpcxUBebGjkpMDkNBS9tf7KJqcDsmL936EChtVva71Dw2q4tQcuVC+hAUy4an2NO/sYXmwRA==", + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/socket.io-adapter/-/socket.io-adapter-2.5.5.tgz", + "integrity": "sha512-eLDQas5dzPgOWCk9GuuJC2lBqItuhKI4uxGgo9aIV7MYbk2h9Q6uULEh8WBzThoI7l+qU9Ast9fVUmkqPP9wYg==", "dev": true, "requires": { - "ws": "~8.11.0" + "debug": "~4.3.4", + "ws": "~8.17.1" } }, "socket.io-parser": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.3.tgz", - "integrity": "sha512-JMafRntWVO2DCJimKsRTh/wnqVvO4hrfwOqtO7f+uzwsQMuxO6VwImtYxaQ+ieoyshWOTJyV0fA21lccEXRPpQ==", + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", "dev": true, "requires": { "@socket.io/component-emitter": "~3.1.0", @@ -6179,9 +6183,9 @@ "dev": true }, "ws": { - "version": "8.11.0", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz", - "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "requires": {} }, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index a5b9c84c63eb9..55292b35e1e38 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -147,6 +147,23 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_ } // namespace sparse_attention namespace attention { + +enum class AttentionBackend : int { + FLASH_ATTENTION = 1, + EFFICIENT_ATTENTION = 2, + TRT_FUSED_ATTENTION = 4, + CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention. 
+ MATH = 16, // unfused kernel cannot be disabled right now. + + // The following kernels might be deprecated in the future. + TRT_FLASH_ATTENTION = 32, + TRT_CROSS_ATTENTION = 64, + TRT_CAUSAL_ATTENTION = 128, +}; + +// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled). +constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"; + // Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; @@ -157,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT // Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; +// Environment variable to enable or disable cuDNN flash attention. +constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION"; + // Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled). constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; @@ -166,11 +186,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF // Environment variable to enable or disable flash attention. Default is 0 (enabled). constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; -// Minimum sequence length to enable memory efficient attention in FP32. -constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256; +// Minimum sequence length to perfer memory efficient attention when data type is float32 +constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32"; + +// Default value for minimum sequence length to enable memory efficient attention in FP32. +constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256; // Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"; + // Default value for the above setting. 
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index d9907f09121d0..cacd65313ebcc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -3,7 +3,6 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention.h" #include "contrib_ops/cuda/bert/bert_padding.h" @@ -40,36 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { - disable_fused_self_attention_ = - sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + kernel_options_ = this->GetAttentionKernelOptions(); - enable_trt_flash_attention_ = - sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); - enable_fused_causal_attention_ = - sizeof(T) == 2 && - ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = - ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention(); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = - sizeof(T) != 2 || - onnxruntime::ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); + + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); } template @@ -134,7 +114,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads, parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. 
- if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } // Allocate buffers @@ -220,7 +200,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == past && nullptr == present && (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); if (use_memory_efficient_attention) { @@ -231,6 +211,20 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + + debug_info.Print("Attention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + cublasHandle_t cublas = GetCublasHandle(context); typedef typename ToCudaType::MappedType CudaT; @@ -268,7 +262,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_fused_cross_attention, use_memory_efficient_attention); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); - ; typedef typename ToCudaType::MappedType CudaT; AttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index acafb379d713f..0c7d3621f95ef 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -8,6 +8,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -27,9 +28,10 @@ class Attention final : public CudaKernel, public AttentionBase { bool enable_trt_flash_attention_; bool enable_fused_causal_attention_; bool disable_memory_efficient_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; + + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc new file mode 100644 index 0000000000000..28a095e68131e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
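The new translation unit below centralizes kernel selection for the CUDA attention operators: a positive `sdpa_kernel` provider option is interpreted as a bitmask of `AttentionBackend` values, while zero falls back to the existing `ORT_DISABLE_*` / `ORT_ENABLE_*` environment variables, and `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1` prints which kernel each attention node ends up using. A standalone sketch of how such a bitmask is composed and tested (the enum is copied from `attention_common.h` above; the snippet is illustrative and not part of the patch):

```cpp
#include <cstdio>

// Mirrors onnxruntime::contrib::attention::AttentionBackend from attention_common.h.
enum class AttentionBackend : int {
  FLASH_ATTENTION = 1,
  EFFICIENT_ATTENTION = 2,
  TRT_FUSED_ATTENTION = 4,
  CUDNN_FLASH_ATTENTION = 8,
  MATH = 16,
  TRT_FLASH_ATTENTION = 32,
  TRT_CROSS_ATTENTION = 64,
  TRT_CAUSAL_ATTENTION = 128,
};

int main() {
  // Allow only flash attention plus the unfused math fallback: 1 | 16 == 17.
  int sdpa_kernel = static_cast<int>(AttentionBackend::FLASH_ATTENTION) |
                    static_cast<int>(AttentionBackend::MATH);

  // AttentionKernelOptions::Initialize() tests each bit in the same way.
  bool use_flash = (sdpa_kernel & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
  bool use_efficient = (sdpa_kernel & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
  std::printf("sdpa_kernel=%d flash=%d efficient=%d\n",
              sdpa_kernel, static_cast<int>(use_flash), static_cast<int>(use_efficient));
  return 0;
}
```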
+ +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#include +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" + +using namespace onnxruntime::contrib::attention; + +namespace onnxruntime { +void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { + if (value > 0) { + use_flash_attention_ = (value & static_cast(AttentionBackend::FLASH_ATTENTION)) > 0; + use_efficient_attention_ = (value & static_cast(AttentionBackend::EFFICIENT_ATTENTION)) > 0; + use_trt_fused_attention_ = (value & static_cast(AttentionBackend::TRT_FUSED_ATTENTION)) > 0; + use_cudnn_flash_attention_ = (value & static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0; + use_unfused_ = (value & static_cast(AttentionBackend::MATH)) > 0; + use_trt_flash_attention_ = (value & static_cast(AttentionBackend::TRT_FLASH_ATTENTION)) > 0; + use_trt_cross_attention_ = (value & static_cast(AttentionBackend::TRT_CROSS_ATTENTION)) > 0; + use_trt_causal_attention_ = (value & static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0; + } else { + use_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFlashAttention, false); + use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); + use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); + use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); + use_unfused_ = true; + use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableTrtFlashAttention, false); + use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedCrossAttention, false); + use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault(kEnableFusedCausalAttention, false); + } + + enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault(kEnableAttentionKernelDebugInfo, false); + + // When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing. + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + kMinSeqLenForFlashAttentionPackedQKV, + value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV); + + min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault( + kMinSeqLenForEfficientAttentionFp32, + value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32); + + if (use_build_flag) { + // Some kernels can be disabled at build time. If they are disabled, we should not use them. 
+#ifndef USE_FLASH_ATTENTION + use_flash_attention_ = false; +#endif + +#ifndef USE_MEMORY_EFFICIENT_ATTENTION + use_efficient_attention_ = false; +#endif + } +} + +void AttentionKernelOptions::InitializeOnce( + int sdpa_kernel, bool use_build_flag) { + std::call_once(this->initialize_once_flag_, [&]() { + this->Initialize(sdpa_kernel, use_build_flag); + if (this->enable_kernel_debug_info_) { + this->Print(); + } + }); +} + +void AttentionKernelOptions::Print() const { + std::stringstream sstream; + sstream << "AttentionKernelOptions:"; + sstream << " FLASH_ATTENTION=" << int(use_flash_attention_); + sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_); + sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_); + sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_); + sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_); + sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_); + sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_); + sstream << " MATH=" << int(use_unfused_); + + if (!use_unfused_) { + sstream << std::endl + << "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored."; + } + + // Output text in Cyan color to make it easier to spot + std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; +} + +// Classify the kernel used in TRT fused runner. +void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) { + if (causal) { + use_trt_causal_attention = true; + } else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) { + use_trt_flash_attention = true; + } else { + use_trt_fused_attention = true; + } +} + +void AttentionKernelDebugInfo::Print(const char* operator_name, + const std::string& node_name, + bool is_float16, + bool is_bfloat16) const { + std::stringstream sstream; + sstream << "Operator=" << operator_name; + + if (node_name.length() > 0) { + sstream << " Node=" << node_name; + } + + if (is_bfloat16) { + sstream << " DataType=bf16"; + } else if (is_float16) { + sstream << " DataType=fp16"; + } else { + sstream << " DataType=fp32"; + } + + if (use_flash_attention.has_value() && use_flash_attention.value()) { + sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value()); + } + + if (use_efficient_attention.has_value() && use_efficient_attention.value()) { + sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value()); + } + + if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) { + sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value()); + } + + if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) { + sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value()); + } + + if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) { + sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value()); + } + + if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) { + sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value()); + } + + if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) { + sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value()); + } + + bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) || + (use_efficient_attention.has_value() && use_efficient_attention.value()) || + 
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) || + (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) || + (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) || + (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) || + (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()); + + // Fall back to unfused when no fused kernel is enabled. + if (!use_fused) { + sstream << " MATH=1"; + } + + // Output text in Cyan color to make it easier to spot. + std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl; +} + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h new file mode 100644 index 0000000000000..bd7df5f490c76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include + +namespace onnxruntime { +struct AttentionKernelDebugInfo { + std::optional use_flash_attention = std::nullopt; + std::optional use_efficient_attention = std::nullopt; + std::optional use_trt_fused_attention = std::nullopt; + std::optional use_cudnn_flash_attention = std::nullopt; + std::optional use_trt_flash_attention = std::nullopt; + std::optional use_trt_cross_attention = std::nullopt; + std::optional use_trt_causal_attention = std::nullopt; + void SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length); + void Print(const char* operator_name, const std::string& node_name, bool is_float16, bool is_bfloat16) const; +}; + +class AttentionKernelOptions { + public: + void InitializeOnce(int sdpa_kernel, bool use_build_flag); + + bool UseFlashAttention() const { return use_flash_attention_; } + bool UseEfficientAttention() const { return use_efficient_attention_; } + bool UseTrtFusedAttention() const { return use_trt_fused_attention_; } + bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; } + bool UseUnfusedAttention() const { return use_unfused_; } + bool UseTrtFlashAttention() const { return use_trt_flash_attention_; } + bool UseTrtCrossAttention() const { return use_trt_cross_attention_; } + bool UseTrtCausalAttention() const { return use_trt_causal_attention_; } + + bool AllowDebugInfo() const { return enable_kernel_debug_info_; } + + int MinSeqLenForFlashAttentionPackedQkv() const { return min_seq_len_for_flash_attention_packed_qkv_; } + int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; } + + protected: + void Print() const; + + void Initialize(int value, bool use_build_flag); + + private: + bool use_flash_attention_{true}; + bool use_efficient_attention_{true}; + bool use_trt_fused_attention_{true}; + bool use_cudnn_flash_attention_{false}; + bool use_unfused_{true}; + + bool use_trt_flash_attention_{true}; + bool use_trt_cross_attention_{true}; + + // Causal attention is disabled by default in #14732. 
+ bool use_trt_causal_attention_{false}; + + bool enable_kernel_debug_info_{false}; + + int min_seq_len_for_flash_attention_packed_qkv_{0}; + + int min_seq_len_for_efficient_attention_fp32_{0}; + + std::once_flag initialize_once_flag_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 3b6ad238cc826..797f9b0a1ea47 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -52,20 +52,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); -#else - disable_flash_attention_ = true; -#endif + kernel_options_ = this->GetAttentionKernelOptions(); + + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION // Memory efficient attention only supports float and float16, not bfloat16. - disable_memory_efficient_attention_ = std::is_same::value || - ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = std::is_same::value || !kernel_options_->UseEfficientAttention(); + if (!disable_flash_attention_) { zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); } @@ -161,7 +154,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size); if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -201,6 +194,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream()); #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + + debug_info.Print("GroupQueryAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + // seqlens_k buffer size_t seqlens_k_bytes = 0; seqlens_k_bytes = sizeof(int) * parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 15573ece166fc..4ff5b0a59f021 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -6,6 +6,7 @@ #include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -32,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel { bool 
disable_memory_efficient_attention_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ba8b00df07e06..b96140f3897f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/providers/cuda/cuda_common.h" -#include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" @@ -47,31 +46,16 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); - disable_fused_self_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + kernel_options_ = this->GetAttentionKernelOptions(); - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); - disable_fused_cross_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); + disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); // Allocate cache buffers constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast(kCumulatedSequenceLengthCacheMaxBatchSize) + 1); @@ -155,7 +139,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. 
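The check that follows reads this threshold from the shared kernel options rather than a per-operator member. A minimal sketch of how the threshold is resolved, mirroring `AttentionKernelOptions::Initialize` above (the helper name is hypothetical):

```cpp
#include <cstdlib>
#include <string>

// Hypothetical helper illustrating MinSeqLenForFlashAttentionPackedQkv():
// the environment variable wins when set; otherwise the default is 513,
// dropping to 0 when an explicit sdpa_kernel bitmask was provided so that
// tests exercise the requested kernel regardless of sequence length.
int ResolvePackedQkvThreshold(int sdpa_kernel) {
  const char* env = std::getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV");
  if (env != nullptr && *env != '\0') {
    return std::stoi(env);
  }
  return sdpa_kernel > 0 ? 0 : 513;  // 513 == kDefaultMinSeqLenForFlashAttentionPackedQKV
}
```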
if (use_flash_attention && key == nullptr && value == nullptr && - parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } // Allocate buffers @@ -229,9 +213,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } #if USE_MEMORY_EFFICIENT_ATTENTION + int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32(); bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 - parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 || - parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32; + parameters.sequence_length >= length_threshold || + parameters.kv_sequence_length >= length_threshold; bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; @@ -249,6 +234,21 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_fp16_runner_ != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + + debug_info.Print("MultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. 
bool no_qkv_workspace = nullptr == value && diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 86a32c92ce003..26e38dbad9fd7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -8,6 +8,7 @@ #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -31,12 +32,12 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; + const AttentionKernelOptions* kernel_options_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 0146cce30c7d1..a1149ddbf99f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -33,12 +33,11 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) template -TrtFusedAttention::TrtFusedAttention() { - disable_fused_runner_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); - - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); +TrtFusedAttention::TrtFusedAttention(const OpKernelInfo& info) + : CudaKernel(info) { + kernel_options_ = this->GetAttentionKernelOptions(); + disable_fused_runner_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); + enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); } template @@ -86,7 +85,8 @@ template class TrtFusedAttention; template class TrtFusedAttention; template -PackedAttention::PackedAttention(const OpKernelInfo& info) : TrtFusedAttention(), CudaKernel(info) { +PackedAttention::PackedAttention(const OpKernelInfo& info) + : TrtFusedAttention(info) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); @@ -268,7 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; - parameters.use_tf32 = UseTF32(); + parameters.use_tf32 = this->UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -295,6 +295,19 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { } #endif + if (this->kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length); + } + + 
debug_info.Print("PackedAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + typedef typename ToCudaType::MappedType CudaT; CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -313,7 +326,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, this->UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias @@ -341,7 +354,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_memory_efficient_attention = use_memory_efficient_attention; - return QkvToContext(device_prop, cublas, Stream(context), parameters, data); + return QkvToContext(device_prop, cublas, this->Stream(context), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index f00c112fc73d2..67b420764169a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -9,6 +9,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { @@ -17,14 +18,16 @@ namespace cuda { using namespace onnxruntime::cuda; template -class TrtFusedAttention { +class TrtFusedAttention : public CudaKernel { public: - TrtFusedAttention(); + TrtFusedAttention(const OpKernelInfo& info); protected: MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; protected: + const AttentionKernelOptions* kernel_options_; + bool disable_fused_runner_; bool enable_trt_flash_attention_; mutable std::unique_ptr fused_fp16_runner_; @@ -32,7 +35,7 @@ class TrtFusedAttention { }; template -class PackedAttention final : public TrtFusedAttention, public CudaKernel { +class PackedAttention final : public TrtFusedAttention { public: PackedAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 3fbbafc01254e..53e96fc732a33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -35,30 +35,16 @@ REGISTER_KERNEL_TYPED(MLFloat16) template PackedMultiHeadAttention::PackedMultiHeadAttention(const OpKernelInfo& info) - : TrtFusedAttention(), CudaKernel(info) { + : TrtFusedAttention(info) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); scale_ = info.GetAttrOrDefault("scale", 0.0f); -#if USE_FLASH_ATTENTION - disable_flash_attention_ = sizeof(T) != 2 || onnxruntime::ParseEnvironmentVariableWithDefault( - attention::kDisableFlashAttention, false); - min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( - 
attention::kMinSeqLenForFlashAttentionPackedQKV, - attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); -#else - disable_flash_attention_ = true; - min_seq_len_for_flash_attention_packed_qkv_ = 0; -#endif + disable_flash_attention_ = sizeof(T) != 2 || !this->kernel_options_->UseFlashAttention(); -#if USE_MEMORY_EFFICIENT_ATTENTION - disable_memory_efficient_attention_ = onnxruntime::ParseEnvironmentVariableWithDefault( - attention::kDisableMemoryEfficientAttention, false); -#else - disable_memory_efficient_attention_ = true; -#endif + disable_memory_efficient_attention_ = !this->kernel_options_->UseEfficientAttention(); } template @@ -228,7 +214,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); PackedAttentionParameters parameters; - parameters.use_tf32 = UseTF32(); + parameters.use_tf32 = this->UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, @@ -255,7 +241,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. if (use_flash_attention && key == nullptr && value == nullptr && - parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + parameters.sequence_length < this->kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } } @@ -271,11 +257,25 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = is_good_for_rpb && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); } #endif + if (this->kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_runner != nullptr) { + debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length); + } + + debug_info.Print("PackedMultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + typedef typename ToCudaType::MappedType CudaT; cublasHandle_t cublas = this->GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h index e30c603dc30aa..9b52a70fc6181 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h @@ -4,13 +4,14 @@ #pragma once #include "contrib_ops/cuda/bert/packed_attention.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" namespace onnxruntime { namespace contrib { namespace cuda { template -class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaKernel { +class PackedMultiHeadAttention final : public TrtFusedAttention { public: PackedMultiHeadAttention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; @@ -32,7 +33,6 @@ class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaK bool 
disable_memory_efficient_attention_; bool disable_flash_attention_; - int min_seq_len_for_flash_attention_packed_qkv_; }; } // namespace cuda diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4298551aec412..7da65f18ccacb 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -13,6 +13,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) @@ -132,12 +133,12 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); - rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; case TransformerLevel::Level2: rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; @@ -187,7 +188,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -287,6 +289,10 @@ InlinedVector> GenerateTransformers( onnxruntime::kJsExecutionProvider}; const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -300,7 +306,10 @@ InlinedVector> GenerateTransformers( if (!qdq_is_int8_allowed) { transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); } - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + SatApplyContextVariant{}, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -409,7 +418,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -423,12 +433,18 @@ InlinedVector> GenerateTransformersForMinimalB const bool qdq_is_int8_allowed = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, QDQIsInt8Allowed() ? 
"1" : "0") == "1"; - + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; if (!disable_quant_qdq) { - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + apply_context, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc index 7417212c570c8..e756ffe78a289 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc @@ -13,13 +13,15 @@ namespace onnxruntime { bool ReluQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) || + !graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { return false; } // if Relu is followed by QuantizeLinear, it can be fused into QuantizeLinear potentially const auto& next_node = *node.OutputNodesBegin(); - if (!QDQ::MatchQNode(next_node)) { + if (!graph_utils::IsSupportedProvider(next_node, {kCpuExecutionProvider}) || + !QDQ::MatchQNode(next_node)) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 3497ea4c85523..74fecb0427e14 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -2,10 +2,12 @@ // Licensed under the MIT License. 
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" - #include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/optimizer/initializer.h" #include "core/graph/node_attr_utils.h" #include "core/framework/tensorprotoutils.h" +#include "core/mlas/inc/mlas_q4.h" + namespace onnxruntime { namespace QDQ { @@ -273,6 +275,175 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } +DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) + : accuracy_level_{accuracy_level}, + domain_{kMSDomain}, + op_type_{"MatMulNBits"}, + value_moves_{[]() { + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; + return std::vector{ + MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), + MoveAll(target, ArgType::kOutput)}; + }()}, + intra_op_thread_pool_{intra_op_thread_pool} { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); +} + +NodeAttributes +DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) const { + NodeAttributes extra_attributes; + + const auto* dq_node = runtime_state.selected_nodes.Input(0); + auto& attrs = dq_node->GetAttributes(); + const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); + + utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); + // currently only 4bits is supported. In the future, derive bits from DQ's weight type. + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); + + return extra_attributes; +} + +Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, + const NodesToOptimize& selected_nodes, + Node& replacement_node) const { + const auto* dq_node = selected_nodes.Input(0); + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + const auto& attrs = dq_node->GetAttributes(); + + const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; + graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto); + graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto); + if (zp_arg) { + graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); + } + + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = attrs.at("block_size").i(); + auto quant_num = (K + block_size - 1) / block_size; + auto blob_bytes = (block_size + 1) / 2; + + // Unfortunately iterating the source data is complicated, the data maybe in + // external file, a raw buffer, or a repeated field depending on the data + // type. UnpackTensor() already contains some of these logic and is closest + // to what we need. But it does not handle external data. 
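Note for reviewers, since the initializer shapes created below are easy to misread in diff form: for a K x N weight that the DQ node block-quantizes along axis 0, the fused MatMulNBits weight is repacked per output column into shape {N, ceil(K / block_size), (block_size + 1) / 2} bytes, with N * ceil(K / block_size) scale elements and, when a zero point is present (or the weight type is UINT4), N * ceil(ceil(K / block_size) / 2) zero-point bytes. The standalone sketch below merely replays that arithmetic with illustrative values (K, N and block_size are not taken from this change) and includes the power-of-two block_size constraint that DQMatMulNodeGroupSelector enforces.

// Sketch only: mirrors the shape arithmetic in DQMatMulToMatMulNBitsAction::ProcessNewNode.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t K = 4096;         // illustrative weight rows (DQ quantization axis 0)
  const int64_t N = 11008;        // illustrative weight columns
  const int64_t block_size = 32;  // illustrative DQ block size

  // selector requirement: block_size >= 16 and a power of two
  const bool block_size_ok = block_size >= 16 && ((block_size - 1) & block_size) == 0;

  const int64_t quant_num = (K + block_size - 1) / block_size;  // quantization blocks per column
  const int64_t blob_bytes = (block_size + 1) / 2;              // two 4-bit values packed per byte

  std::cout << "block_size ok:       " << block_size_ok << "\n"
            << "packed weight shape: {" << N << ", " << quant_num << ", " << blob_bytes << "}\n"
            << "scale elements:      " << (N * quant_num) << "\n"
            << "zero-point bytes:    " << (N * ((quant_num + 1) / 2)) << "\n";
  return 0;
}

For completeness, the accuracy_level passed into this action originates from the session configuration entry read in graph_transformer_utils.cc with a default of "4". A minimal, hypothetical usage sketch follows, assuming the public C++ API and the key constant declared in onnxruntime_session_options_config_keys.h; "1" is only an illustrative value inside the 0-4 range enforced by the action's constructor.

// Sketch: selecting the MatMulNBits accuracy level used by the DQ -> MatMul fusion.
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  session_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, "1");
  // ... create the inference session with session_options as usual ...
  return 0;
}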
+ Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); + std::optional zp_src; + Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(weight_arg->Name() + "_T"), + std::vector{N, quant_num, blob_bytes}); + Initializer scale_dst(static_cast(scale_src.data_type()), + graph.GenerateNodeArgName(scale_arg->Name() + "_T"), + std::vector{N * quant_num}); + std::optional zp_dst; + + if (zp_tensor_proto) { + zp_src.emplace(*zp_tensor_proto, graph.ModelPath()); + zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + std::vector{N * ((quant_num + 1) / 2)}); + } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), + std::vector{N * ((quant_num + 1) / 2)}); + } + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } else { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } + } else { + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + + } else { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? 
zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } + } + + ONNX_NAMESPACE::TensorProto weight_T_tp; + ONNX_NAMESPACE::TensorProto scale_T_tp; + std::optional zp_T_tp; + + // TODO(fajin): external_data to memory location to avoid arena allocation + // https://github.com/microsoft/onnxruntime/pull/12465 + weight_dst.ToProto(weight_T_tp); + scale_dst.ToProto(scale_T_tp); + if (zp_dst) { + zp_T_tp.emplace(); + zp_dst->ToProto(zp_T_tp.value()); + } + + auto& input_defs = replacement_node.MutableInputDefs(); + input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); + replacement_node.MutableInputArgsCount().push_back(1); + input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); + replacement_node.MutableInputArgsCount().push_back(1); + + if (zp_T_tp) { + input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); + replacement_node.MutableInputArgsCount().push_back(1); + } + + return Status::OK(); +} + static std::vector GetGemmMoveInfo(bool does_q_node_exist) { NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0}; NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1}; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 8179a030508a5..47821619db65a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -3,7 +3,12 @@ #pragma once +#include +#include +#include + #include "core/optimizer/selectors_actions/actions.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -76,6 +81,30 @@ struct MatMulReplaceWithQLinear : public Action { BinaryReplaceWithQLinear qlinear_matmul_replacer_; }; +// used together with DQMatMulNodeGroupSelector, which does the sanity check +struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { + DQMatMulToMatMulNBitsAction(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool); + + private: + std::string OpType(const RuntimeState&) const override { return op_type_; } + + std::string Domain(const RuntimeState&) const override { return domain_; } + + NodeAttributes ExtraAttributes(const RuntimeState&) const override; + + std::vector ValueMoves(const RuntimeState&) const override { return value_moves_; } + + // transpose initializers, and add to the MatMulNBits inputs + Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const override; + + const int64_t accuracy_level_; + const std::string domain_; + const std::string op_type_; + const std::vector value_moves_; + concurrency::ThreadPool* intra_op_thread_pool_; +}; + struct GemmReplaceWithQuant : public Action { GemmReplaceWithQuant(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 80ead8f8c68d6..17e66a3953b97 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -228,6 +228,30 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i #endif } +void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, + int64_t qdq_matmulnbits_accuracy_level, + 
concurrency::ThreadPool* intra_op_thread_pool) { + // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. + // DQ's weight is int4/uint4. DQ's scale is float/float16. + // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. + const std::string action_name{"DQMatMulToMatMulNBits"}; + + std::unique_ptr action = + std::make_unique(qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); + +#if !defined(ORT_MINIMAL_BUILD) + std::unique_ptr selector = std::make_unique(); + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"MatMul", {}}}, + std::move(selector), + std::move(action)); + +#else + qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); +#endif +} + void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C(optional), 3=Gemm, 4=Q Y(optional) // Replace with QGemm @@ -271,7 +295,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { +SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -283,17 +309,22 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed); GemmQDQRules(qdq_selector_action_registry); WhereQDQRules(qdq_selector_action_registry); + DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer( - bool is_int8_allowed, const SatApplyContextVariant& apply_context) +QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 1780923f3f273..ba636f76d1900 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -5,6 +5,7 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" #include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -21,7 +22,10 @@ Transformer that fuses QDQ and fp32 ops into quantized ops. 
*/ class QDQSelectorActionTransformer : public SelectorActionTransformer { public: - QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}); + QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context = {}, + int64_t qdq_matmulnbits_accuracy_level = 4, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 09705f61c82ce..6e93445c7c5c7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -414,6 +414,91 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } +bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector<const Node*>& dq_nodes, + const std::vector<const Node*>& q_nodes) const { + // Should not have any Q nodes + if (!q_nodes.empty()) { + return false; + } + + const auto& graph = graph_viewer.GetGraph(); + + // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output + if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + return false; + } + + // DQ must be MatMul's second input + if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + return false; + } + + // DQ weight/zero points types are int4/uint4, scales/output types are float or float16 + const auto* weight_arg = dq_nodes[0]->InputDefs()[0]; + const auto* scale_arg = dq_nodes[0]->InputDefs()[1]; + const auto* zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr; + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && + dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + return false; + } + + if (!Is4BitIntType(dt_weight)) { + return false; + } + + // DQ is blockwise quantized along axis 0, and block_size must be a power of 2 and >= 16 + const auto& dq_attrs = dq_nodes[0]->GetAttributes(); + if (const auto a_iter = dq_attrs.find("axis"); + a_iter == dq_attrs.end() || a_iter->second.i() != 0) { + return false; + } + + const auto a_iter = dq_attrs.find("block_size"); + if (a_iter == dq_attrs.end()) { + return false; + } + + auto block_size = a_iter->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) { + return false; + } + + // weight, scale and zero points (if present) must be constants + const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); + const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); + const auto* zp_tensor_proto = zero_point_arg ?
graph.GetConstantInitializer(zero_point_arg->Name(), true) : nullptr; + + if (!weight_tensor_proto || !scale_tensor_proto) { + return false; + } + + if (zero_point_arg && !zp_tensor_proto) { + return false; + } + + // weight, scale and zero points (if exists) must have the rank 2 + if (weight_tensor_proto->dims_size() != 2 || + scale_tensor_proto->dims_size() != 2 || + (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + return false; + } + + // check weight, scale and zero points (if exists) shapes + if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || + weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || + (zp_tensor_proto && + (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || + zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + return false; + } + + return true; +} + bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 1a2a620acb480..491a15b62cb03 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -204,6 +204,14 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { bool allow_4bit_; }; +// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +class DQMatMulNodeGroupSelector : public NodeGroupSelector { + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmNodeGroupSelector : public NodeGroupSelector { @@ -358,6 +366,13 @@ class MatMulSelector : public BaseSelector { allow_16bit, allow_4bit)) {} }; +// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +class DQMatMulToMatMulNBitsSelector : public BaseSelector { + public: + explicit DQMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) + : BaseSelector(std::make_unique(), compatible_providers) {} +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.cc b/onnxruntime/core/optimizer/selectors_actions/actions.cc index c8d5acbf66b78..bb4033afedc49 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.cc +++ b/onnxruntime/core/optimizer/selectors_actions/actions.cc @@ -102,12 +102,14 @@ static Status CreateReplacementNode(Graph& graph, Status ReplaceWithNew::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { const RuntimeState runtime_state{graph, selected_nodes}; + Node* replacement{}; ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes, OpType(runtime_state), Domain(runtime_state), ExtraAttributes(runtime_state), ValueMoves(runtime_state), - /* only_update_dest_definitions */ false, nullptr)); + /* only_update_dest_definitions */ false, &replacement)); + ORT_RETURN_IF_ERROR(ProcessNewNode(graph, selected_nodes, *replacement)); return node_remover_.Run(graph, selected_nodes); } diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h index 9384bfa7027cd..465ae38565b15 100644 --- 
a/onnxruntime/core/optimizer/selectors_actions/actions.h +++ b/onnxruntime/core/optimizer/selectors_actions/actions.h @@ -158,6 +158,12 @@ struct ReplaceWithNew : public Action { // specifies how the inputs and outputs for the replaced nodes are moved to the new node virtual std::vector ValueMoves(const RuntimeState&) const = 0; + // For the changes that cannot be done by simply moving node args around, use this method to make + // additional changes to the new node and the graph. e.g., DQMatMulToMatMulNBitsAction transposes + // the second weight of MatMul ops and create new node args. + // Note: This method is only used in Run(), but not in RunForSave(). + virtual Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const { return Status::OK(); } + RemoveNodes node_remover_; }; @@ -187,5 +193,4 @@ struct ReplaceWithNewFixed : public ReplaceWithNew { const NodeAttributes extra_attrs_; const std::vector value_moves_; }; - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index b8ebbd05a2a20..e1f148fa93e23 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -50,8 +50,8 @@ bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, } } -bool IsInputSupported(const Node& node, const NodeArg& input, - const OpBuilderInputParams& input_params, const logging::Logger& logger) { +bool IsInputSupported(const Node& node, const NodeArg& input, const OpBuilderInputParams& input_params, + const logging::Logger& logger, bool allow_empty_input) { if (!input.Exists()) { // optional input that is not provided return true; @@ -84,16 +84,10 @@ bool IsInputSupported(const Node& node, const NodeArg& input, return false; } - if (dim == 0) { - if (node.OpType() == "Resize" && &input == node.InputDefs()[1]) { - // one special case. Resize 'roi' input was originally a required input but is rarely used. - // ROI is not supported in the CoreML implementation so we will ignore the value, but is often added - // (at least in the unit tests) as an initializer with shape {0}. - } else { - LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name - << ", shape: " << Shape2String(shape); - return false; - } + if (dim == 0 && !allow_empty_input) { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. 
Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index 300de2dedd122..0acaa0dd8a4a3 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -30,7 +30,8 @@ OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, const IOpBuilder* GetOpBuilder(const Node& node); bool IsInputSupported(const Node& node, const NodeArg& node_arg, const OpBuilderInputParams& input_params, - const logging::Logger& logger); + const logging::Logger& logger, + bool allow_empty_input = false); bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 83a572f4b60fa..2cae85a0a1c8d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -74,7 +74,7 @@ bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { for (const auto* input : node.InputDefs()) { - if (!IsInputSupported(node, *input, input_params, logger)) { + if (!IsInputSupported(node, *input, input_params, logger, allow_empty_tensor_as_input_)) { return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 4a23640d0f34c..071008520fbdc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -28,6 +28,10 @@ class BaseOpBuilder : public IOpBuilder { void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: + explicit BaseOpBuilder(bool allow_empty_tensor_as_input = false) + : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { + } + // currently we only support float static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, const logging::Logger& logger); @@ -50,6 +54,8 @@ class BaseOpBuilder : public IOpBuilder { virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const = 0; + + const bool allow_empty_tensor_as_input_; // some operators can handle ignoring an empty tensor as input }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 3400f09b4056f..65b5c17f2c6a6 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -1,13 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include +#include #include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" +#include "core/providers/utils.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,6 +20,11 @@ namespace onnxruntime { namespace coreml { class ResizeOpBuilder : public BaseOpBuilder { + public: + // allow roi and scales potentially being empty inputs that are ignored during processing + ResizeOpBuilder() : BaseOpBuilder(/*allow empty inputs*/ true) {} + + private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -29,196 +36,382 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing // We only support Resize opset 11+ here int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } + + bool SupportsMLProgram() const override { return true; } }; namespace { -bool GetResizeScales(const InitializedTensorSet& initializers, - const Node& node, std::vector& scales, - const logging::Logger&) { +std::vector GetAxes(const NodeAttrHelper& helper, size_t input_rank) { + auto axes = helper.Get("axes", std::vector{}); + if (axes.empty()) { + axes.resize(input_rank); + std::iota(axes.begin(), axes.end(), 0); + } else { + for (auto& value : axes) { + if (value < 0) { + value = HandleNegativeAxis(value, input_rank); + } + } + } + + return axes; +} + +bool GetValidatedResizeScales(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& input_shape, + const std::vector& axes, + std::vector& scales, + const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 3) + int64_t input_rank = input_shape.size(); + + if (input_shape[input_rank - 2] == -1 || input_shape[input_rank - 1] == -1) { + LOGS(logger, VERBOSE) << "Resize with 'scales' requires the H and W dimensions to have fixed values"; return false; + } - const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); - if (scales_tensor.dims_size() != 1 || scales_tensor.dims()[0] != 4) + const auto* scales_tensor = graph_viewer.GetConstantInitializer(input_defs[2]->Name()); + if (!scales_tensor) { + LOGS(logger, VERBOSE) << "Resize 'scales' input must be a constant initializer"; return false; - Initializer unpacked_tensor(scales_tensor); + } + + Initializer unpacked_tensor(*scales_tensor); auto scales_data = unpacked_tensor.DataAsSpan(); - scales = std::vector{scales_data.begin(), scales_data.end()}; + scales.assign(scales_data.begin(), scales_data.end()); + + for (size_t idx = 0, end = axes.size(); idx < end; ++idx) { + auto axis = axes[idx]; + auto scale = scales[idx]; + if (axis < (input_rank - 2) && scale != 1.0f) { + LOGS(logger, VERBOSE) << "Resize only supports resizing the last two axes. 
Scale of axis " << axis << " is " + << scale; + return false; + } + } + return true; } -bool GetResizeOutputSizes(const InitializedTensorSet& initializers, - const Node& node, std::vector& sizes, - const logging::Logger&) { +bool GetValidatedResizeSizes(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& input_shape, + const std::vector& axes, + std::vector& sizes, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 4) - return false; + int64_t input_rank = input_shape.size(); - const auto& sizes_tensor = *initializers.at(input_defs[3]->Name()); - if (sizes_tensor.dims_size() != 1 || sizes_tensor.dims()[0] != 4) + const auto* sizes_tensor = graph_viewer.GetConstantInitializer(input_defs[3]->Name()); + if (!sizes_tensor) { + LOGS(logger, VERBOSE) << "Resize 'sizes' input must be a constant initializer"; return false; - Initializer unpacked_tensor(sizes_tensor); + } + + Initializer unpacked_tensor(*sizes_tensor); auto sizes_data = unpacked_tensor.DataAsSpan(); - sizes = std::vector(sizes_data.begin(), sizes_data.end()); + sizes.assign(sizes_data.begin(), sizes_data.end()); + + for (size_t idx = 0, end = axes.size(); idx < end; ++idx) { + auto axis = axes[idx]; + auto cur_size = input_shape[idx]; + auto new_size = sizes[idx]; + if (axis < (input_rank - 2) && cur_size != new_size) { + LOGS(logger, VERBOSE) << "Resize only supports resizing the last two axes. Input rank: " << input_rank + << " Change to size of axis " << axis << " from " << cur_size << " to " << new_size; + return false; + } + } + return true; } } // namespace void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - // We don't really use ROI here, so add it to skipped list if it's an initializer tensor - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI - model_builder.AddInputToSkip(node.InputDefs()[1]->Name()); // ROI - - // We will still add scales to the skipped list even sizes are present - // since there is no use of it, we will not process it later - model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // scales - model_builder.AddInputToSkip(node.InputDefs()[2]->Name()); // scales - - if (node.InputDefs().size() > 3) { - model_builder.AddInitializerToSkip(node.InputDefs()[3]->Name()); // sizes - model_builder.AddInputToSkip(node.InputDefs()[3]->Name()); // sizes + const auto& input_defs = node.InputDefs(); + + // In Resize-11 both roi and scales were required even if you were using sizes. + // https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11 + // From Resize-13 on they're all optional. + // + // We don't support roi so would never take a node with meaningful roi input. The roi input can however be provided + // and is ignored unless coordinate_transformation_mode is set to 'tf_crop_and_resize'. + // e.g. our unit tests tend to always provide an empty tensor as roi input instead of as a missing optional input. + // Due to this we always call AddInputToSkip on the roi input. + // + // We require the sizes or scales input to be a constant initializers to take the node (i.e. they won't be an input + // to the CoreML model for the partition, so calling AddInputToSkip isn't relevant). + // Individual values from scales and sizes are added directly to the layer, so we won't use the initializer. + // + // That leaves an edge case for Resize-11 where scales could have been provided as an empty input tensor but + // we're using a constant initializer for sizes. 
In this case AddInputToSkip needs to be called for the scales input. + + model_builder.AddInitializerToSkip(input_defs[1]->Name()); // roi + model_builder.AddInputToSkip(input_defs[1]->Name()); + + if (input_defs[2]->Exists()) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); // scales + } + + if (input_defs.size() > 3 && input_defs[3]->Exists()) { + model_builder.AddInitializerToSkip(input_defs[3]->Name()); // sizes + + if (node.SinceVersion() < 13) { + model_builder.AddInputToSkip(input_defs[2]->Name()); // skip the unused scales input + } } } -Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); + const auto input_defs = node.InputDefs(); + const auto output_defs = node.OutputDefs(); + const auto& graph_viewer = model_builder.GetGraphViewer(); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Error getting input shape"); + size_t input_rank = input_shape.size(); + + // we know we have either a scales or sizes input so this is safe. + // check for sizes first. this handles Resize-11 where scales was a required input but sizes were used if provided. + bool using_sizes = input_defs.size() >= 4 && input_defs[3]->Exists(); + bool using_scales = !using_sizes; - auto* coreml_upsample = layer->mutable_upsample(); NodeAttrHelper helper(node); - const auto mode = helper.Get("mode", "nearest"); - if (mode == "linear") { - coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_BILINEAR); - } else { // we already checked the mode must be NN or Bilinear in IsOpSupportedImpl - coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_NN); + const auto& mode = helper.Get("mode", "nearest"); + bool is_nearest = mode == "nearest"; + bool is_linear = !is_nearest; + + auto axes = GetAxes(helper, input_rank); + std::vector output_scales; + std::vector output_sizes; + size_t num_scales = 0; + size_t num_sizes = 0; + + if (using_scales) { + ORT_RETURN_IF_NOT(GetValidatedResizeScales(graph_viewer, node, input_shape, axes, output_scales, logger), + "Error getting validated scales"); + num_scales = output_scales.size(); + + // special case linear downsample. + // the CoreML implementation seems to be flaky and gives different outputs on different OS versions. + // use bilinear_resize instead. we check in IsOpSupportedImpl that the downsample input is evenly + // divisible by the output size so there's no rounding involved. 
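To make the fallback described above concrete with hypothetical numbers (not taken from this change): an 8x8 spatial input in linear mode with H/W scales of 0.5 is rewritten as a sizes-based resize to 4x4, which then maps to CoreML resize_bilinear rather than upsample_bilinear. Because 8 * 0.5 = 4 divides the input evenly, no rounding is involved, which is what the support check guards against. A rough sketch of the conversion, using these illustrative values:

// Rough sketch of the scales -> sizes rewrite for linear downsampling (illustrative values only).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> input_shape = {1, 3, 8, 8};     // hypothetical NCHW input
  const std::vector<float> scales = {1.f, 1.f, 0.5f, 0.5f};  // downsample H and W by 2
  const size_t rank = input_shape.size();

  std::vector<int64_t> sizes = input_shape;
  // only the last two dims change, mirroring the conversion in AddToModelBuilderImpl
  sizes[rank - 2] = static_cast<int64_t>(input_shape[rank - 2] * scales[rank - 2]);
  sizes[rank - 1] = static_cast<int64_t>(input_shape[rank - 1] * scales[rank - 1]);

  std::cout << "target H x W: " << sizes[rank - 2] << " x " << sizes[rank - 1] << "\n";  // 4 x 4
  return 0;
}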
+ if (is_linear && (output_scales[num_scales - 1] < 1.f || output_scales[num_scales - 2] < 1.f)) { + using_scales = false; + using_sizes = true; + num_sizes = num_scales; + output_sizes = input_shape; + // only the last two dims have their size changed + output_sizes[input_rank - 2] = static_cast(input_shape[input_rank - 2] * output_scales[num_scales - 2]); + output_sizes[input_rank - 1] = static_cast(input_shape[input_rank - 1] * output_scales[num_scales - 1]); + } + } else { + ORT_RETURN_IF_NOT(GetValidatedResizeSizes(graph_viewer, node, input_shape, axes, output_sizes, logger), + "Error getting validated sizes"); + num_sizes = output_sizes.size(); } - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - - if (input_defs.size() >= 3 && input_defs[2]->Exists()) { // use scales - std::vector scales; - ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); - coreml_upsample->add_scalingfactor(static_cast(scales[2])); - coreml_upsample->add_scalingfactor(static_cast(scales[3])); - } else { // we already checked number of inputs in IsOpSupportedImpl - std::vector input_shape; - ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Error getting input shape"); - std::vector output_sizes; - ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), - "Error getting resize output_sizes"); - coreml_upsample->add_scalingfactor(static_cast(output_sizes[2] / input_shape[2])); - coreml_upsample->add_scalingfactor(static_cast(output_sizes[3] / input_shape[3])); - } +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; // NOLINT + + std::string_view coreml_op_type; + if (using_scales) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.upsample_bilinear + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.upsample_nearest_neighbor + coreml_op_type = is_linear ? "upsample_bilinear" : "upsample_nearest_neighbor"; + } else { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resize_bilinear + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resize_nearest_neighbor + coreml_op_type = is_linear ? 
"resize_bilinear" : "resize_nearest_neighbor"; + } + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + + std::string coord_trans_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); + + if (using_scales) { + float scale_height = output_scales[num_scales - 2]; + float scale_width = output_scales[num_scales - 1]; + AddOperationInput(*op, "scale_factor_height", + model_builder.AddScalarConstant(coreml_op_type, "scale_factor_height", scale_height)); + AddOperationInput(*op, "scale_factor_width", + model_builder.AddScalarConstant(coreml_op_type, "scale_factor_width", scale_width)); + + if (is_linear) { + // we only allow these coord modes in the 'is supported' check, + // - half_pixel or pytorch_half_pixel with output size > 1 -> align_corners = false + // - align_corners -> align_corners = true + bool align_corners = coord_trans_mode == "align_corners"; + AddOperationInput(*op, "align_corners", + model_builder.AddScalarConstant(coreml_op_type, "align_corners", align_corners)); + } + } else { + assert(using_sizes); + int64_t target_height = output_sizes[num_sizes - 2]; + int64_t target_width = output_sizes[num_sizes - 1]; + + AddOperationInput(*op, "target_size_height", + model_builder.AddScalarConstant(coreml_op_type, "target_size_height", target_height)); + AddOperationInput(*op, "target_size_width", + model_builder.AddScalarConstant(coreml_op_type, "target_size_width", target_width)); + + if (is_linear) { + // we only allow these coord modes in the 'is supported' check, + // - half_pixel or pytorch_half_pixel with output size > 1 -> UNALIGN_CORNERS + // - align_corners -> STRICT_ALIGN_CORNERS + // - asymmetric -> DEFAULT + std::string sampling_mode_value; + if (coord_trans_mode == "asymmetric") { + sampling_mode_value = "DEFAULT"; + } else if (coord_trans_mode == "align_corners") { + sampling_mode_value = "STRICT_ALIGN_CORNERS"; + } else { + sampling_mode_value = "UNALIGN_CORNERS"; + } + + AddOperationInput(*op, "sampling_mode", + model_builder.AddScalarConstant(coreml_op_type, "sampling_mode", sampling_mode_value)); + } + } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + AddOperationOutput(*op, *output_defs[0]); + model_builder.AddOperation(std::move(op)); + } else // NOLINT +#endif + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto* coreml_upsample = layer->mutable_upsample(); + + // we already checked the mode must be NN or Bilinear in IsOpSupportedImpl + if (is_linear) { + coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_BILINEAR); + } else { + coreml_upsample->set_mode(COREML_SPEC::UpsampleLayerParams_InterpolationMode_NN); + } + + if (using_scales) { + coreml_upsample->add_scalingfactor(static_cast(output_scales[num_scales - 2])); + coreml_upsample->add_scalingfactor(static_cast(output_scales[num_scales - 1])); + } else { + auto scale_height = output_sizes[num_sizes - 2] / input_shape[input_rank - 2]; + auto scale_width = output_sizes[num_sizes - 1] / input_shape[input_rank - 1]; + coreml_upsample->add_scalingfactor(static_cast(scale_height)); + coreml_upsample->add_scalingfactor(static_cast(scale_width)); + } + + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = output_defs[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } bool 
ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); std::vector<int64_t> input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "Resize: input shape was not known"; return false; + } - const auto input_size = input_shape.size(); - if (input_size != 4) { - LOGS(logger, VERBOSE) << "Resize only support 4d shape, input is " - << input_size << "d shape"; + // as we allow empty shapes in the checks done by BaseOpBuilder::HasSupportedInputs we explicitly check for an empty + // input here to be consistent. + // this should never happen in a real model though as a dim with value 0 (i.e. no input data) would typically be a + // dynamic dimension where a previous step had no output (e.g. Loop of zero iterations, NonZero with no matches, + // NonMaxSuppression with no boxes). + if (DoesShapeSpecifyZeroElements(input_shape)) { + LOGS(logger, VERBOSE) << "Resize input shape has dimension values of 0, which is not supported."; return false; } - { // check attributes - NodeAttrHelper helper(node); - const auto mode = helper.Get("mode", "nearest"); - bool is_linear_resize = mode == "linear"; - bool is_nearest_resize = mode == "nearest"; - if (!is_linear_resize && !is_nearest_resize) { - LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode; + const auto input_rank = input_shape.size(); + if (input_params.create_mlprogram) { + if (input_rank < 3 || input_rank > 5) { + LOGS(logger, VERBOSE) << "Resize only supports 3D to 5D input. Got: " << input_rank << "D"; return false; } - - const auto exclude_outside = helper.Get("exclude_outside", 0); - if (exclude_outside != 0) { - LOGS(logger, VERBOSE) << "Resize does not support exclude_outside for now"; + } else { + if (input_rank != 4) { + LOGS(logger, VERBOSE) << "Resize only supports 4D input. 
Got: " << input_rank << "D"; return false; } + } - const auto coord_trans_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); - bool using_asymmetric = coord_trans_mode == "asymmetric"; - if (is_linear_resize) { - // TODO, add support of align_corners and half_pixel - if (!using_asymmetric) { - LOGS(logger, VERBOSE) << "Resize bilinear, unsupported coord_trans_mode, " << coord_trans_mode; - return false; - } - } else { - // nearest neighbor resizing - // For resize using nearest neighbor, we only support coord_trans_mode == "asymmetric" && nearest_mode == "floor" - if (!using_asymmetric) { - LOGS(logger, VERBOSE) << "Resize nearest neighbor, unsupported coord_trans_mode, " << coord_trans_mode; - return false; - } + // check attributes + NodeAttrHelper helper(node); - const auto nearest_mode = helper.Get("nearest_mode", "round_prefer_floor"); - if (nearest_mode != "floor") { - LOGS(logger, VERBOSE) << "Resize nearest neighbor, unsupported nearest_mode, " << nearest_mode; - return false; - } - } + if (helper.Get("antialias", 0) != 0) { + LOGS(logger, VERBOSE) << "Resize does not support antialias"; + return false; } - { // scales and sizes (if present) must be initializers - if (input_defs.size() < 3) { - LOGS(logger, VERBOSE) << "Input scales or sizes of Resize must be known"; - return false; - } + const auto& mode = helper.Get("mode", "nearest"); + bool is_linear = mode == "linear"; + bool is_nearest = mode == "nearest"; + if (!is_linear && !is_nearest) { + LOGS(logger, VERBOSE) << "Resize unsupported input mode: " << mode; + return false; + } - bool using_scales = input_defs.size() >= 3 && input_defs[2]->Exists(); - // scales - if (using_scales && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { - LOGS(logger, VERBOSE) << "scales input of Resize must be a constant initializer"; + if (is_nearest) { + const auto nearest_mode = helper.Get("nearest_mode", "round_prefer_floor"); + if (nearest_mode != "floor") { + LOGS(logger, VERBOSE) << "Resize only supports 'floor' nearest_mode. Got: " << nearest_mode; return false; } + } - // sizes - if (!using_scales && - (input_defs.size() < 4 || - !input_defs[3]->Exists() || - !input_params.graph_viewer.GetConstantInitializer(input_defs[3]->Name()))) { - LOGS(logger, VERBOSE) << "sizes input of Resize must be a constant initializer"; - return false; - } + if (helper.Get("exclude_outside", 0) != 0) { + LOGS(logger, VERBOSE) << "Resize does not support 'exclude_outside'"; + return false; + } - // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (using_scales) { - std::vector scales; - if (!GetResizeScales(initializers, node, scales, logger)) - return false; + const auto keep_aspect_ratio_policy = helper.Get("keep_aspect_ratio_policy", "stretch"); + if (keep_aspect_ratio_policy != "stretch") { + LOGS(logger, VERBOSE) << "Resize only supports keep_aspect_ratio_policy of 'stretch'. Got " + << keep_aspect_ratio_policy; + return false; + } - float scale_n = scales[0]; - float scale_c = scales[1]; - if (scale_n != 1.0f || scale_c != 1.0f) { - LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" - << "Resize of N/C channels are not supported" - << ", scale_n, " << scale_n << ", scale_c, " << scale_c; - return false; - } + // check for sizes first. this handles Resize-11 where scales was a required input but sizes were used if provided. 
+ bool using_sizes = input_defs.size() >= 4 && input_defs[3]->Exists(); + bool using_scales = !using_sizes && input_defs.size() >= 3 && input_defs[2]->Exists(); - // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1 - // TODO support ResizeBilinear - float scale_h = scales[2]; - float scale_w = scales[3]; + if (!using_scales && !using_sizes) { + LOGS(logger, VERBOSE) << "Resize requires 'scales' or 'sizes' input"; + return false; + } + + // 'axes' is from opset 18 on and allows scales or sizes to have entries for the subset of axes. + // we fill with default values if necessary so that the processing is consistent across all supported opsets. + auto axes = GetAxes(helper, input_rank); + std::vector output_scales; + std::vector output_sizes; + + // make sure scales/sizes are constant initializers, and are only modifying the last two dimensions of the input. + if (using_scales) { + if (!GetValidatedResizeScales(input_params.graph_viewer, node, input_shape, axes, output_scales, logger)) { + return false; + } - // Onnx spec requires scale to be a positive float, so we are not checking that here + size_t num_scales = output_scales.size(); + float scale_h = output_scales[num_scales - 2]; + float scale_w = output_scales[num_scales - 1]; + + // NeuralNetwork supports upsample only with round numbers. + // + // ML Program results seem to match if round numbers are involved. When downsampling the scaling value should be + // 1 / . e.g. if input size is 8, scaling factor could be 1/8, 1/4 or 1/2. + if (scale_h >= 1.f && scale_w >= 1.f) { + // upsample (or no-op with both == 1.f that we won't bother special-casing) if (roundf(scale_h) != scale_h) { LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " is not a whole number"; return false; @@ -228,33 +421,57 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; return false; } - } else { - // we are using sizes - std::vector output_sizes; - if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) - return false; - - if (!IsStaticShape(input_shape)) { - LOGS(logger, VERBOSE) << "Input shape with dynamic dimensions is not supported."; + } else if (scale_h <= 1.f && scale_w <= 1.f) { + // downsample + if (input_params.create_mlprogram) { + auto h_in = input_shape[input_rank - 2]; + auto w_in = input_shape[input_rank - 1]; + + if (!utils::IsScalingByAFactorOfN(h_in, scale_h)) { + LOGS(logger, VERBOSE) << "Resize: downsampling scale " << scale_h + << " is not a factor of input height: " << h_in; + return false; + } + + if (!utils::IsScalingByAFactorOfN(w_in, scale_w)) { + LOGS(logger, VERBOSE) << "Resize: downsampling scale " << scale_w + << " is not a factor of input width: " << w_in; + return false; + } + + } else { + LOGS(logger, VERBOSE) << "Resize: downsampling is not supported."; return false; } + } else { + LOGS(logger, VERBOSE) << "Resize: scale_h: " << scale_h << " and scale_w: " << scale_w + << " must both be >= 1 or <= 1"; + return false; + } + } else { + assert(using_sizes); + + if (!GetValidatedResizeSizes(input_params.graph_viewer, node, input_shape, axes, output_sizes, logger)) { + return false; + } - auto output_size_n = output_sizes[0]; - auto output_size_c = output_sizes[1]; - if (output_size_n != input_shape[0] || output_size_c != input_shape[1]) { - LOGS(logger, VERBOSE) << "Output sizes of N/C channel should match the input sizes, " - << "Resize of N/C channels are not 
supported" - << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n - << ". input_size_c, " << input_shape[1] << ", output_size_c, " << output_size_c; + if (input_params.create_mlprogram) { + // no additional requirements + } else { + if (!IsStaticShape(input_shape)) { + // need to convert from sizes to scales when creating the NN layer, so the input H and W are required + LOGS(logger, VERBOSE) << "Resize input shape with dynamic dimensions is not supported."; return false; } - // For now we only support upscale, so the output_size_h and output_size_w should be an integer >= 1 + // For now we only support upsample, so the output_size_h and output_size_w should be an integer >= 1 // TODO support ResizeBilinear - auto output_size_h = output_sizes[2]; - auto output_size_w = output_sizes[3]; - auto input_size_h = input_shape[2]; - auto input_size_w = input_shape[3]; + auto input_size_h = input_shape[input_rank - 2]; + auto input_size_w = input_shape[input_rank - 1]; + + auto num_sizes = output_sizes.size(); // could be smaller than input_rank if axes was used + auto output_size_h = output_sizes[num_sizes - 2]; + auto output_size_w = output_sizes[num_sizes - 1]; // Onnx spec requires output sizes to be a positive integer, so we are not checking that here if (output_size_h % input_size_h != 0) { @@ -271,6 +488,92 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa } } + std::string coord_trans_mode = helper.Get("coordinate_transformation_mode", "half_pixel"); + bool using_asymmetric = coord_trans_mode == "asymmetric"; + + if (input_params.create_mlprogram) { + if (is_nearest) { + // Potential CoreML operators we could map to: + // + // image_resizing.upsample_nearest_neighbor + // - mode: nearest + // - coordinate_transformation_mode: asymmetric + // - 'scales' input + // + // image_resizing.resize_nearest_neighbor + // - mode: nearest + // - coordinate_transformation_mode: asymmetric + // - 'sizes' input + if (!using_asymmetric) { + LOGS(logger, VERBOSE) << "Resize with 'mode' of 'nearest' requires 'coordinate_transformation_mode' of " + "'asymmetric' . 
Got: " + << coord_trans_mode; + return false; + } + } else { + assert(is_linear); + // Potential CoreML operators we could map to: + // + // image_resizing.upsample_bilinear + // - mode: linear + // - 'scales' input + // - coordinate_transformation_mode + // - half_pixel -> align_corners = false + // - align_corners -> align_corners = true + // + // image_resizing.resize_bilinear + // - mode: linear + // - 'sizes' input + // - coordinate_transformation_mode -> sampling_mode + // - half_pixel -> UNALIGN_CORNERS + // - align_corners -> STRICT_ALIGN_CORNERS + // - asymmetric -> DEFAULT + // + + // if output size != 1, coordinate_transformation_mode of pytorch_half_pixel is the same as half_pixel + if (coord_trans_mode == "pytorch_half_pixel") { + int64_t h_out{0}, w_out{0}; + if (using_scales) { + size_t num_scales = output_scales.size(); + h_out = std::llround(input_shape[input_rank - 2] * output_scales[num_scales - 2]); + w_out = std::llround(input_shape[input_rank - 1] * output_scales[num_scales - 1]); + } else { + size_t num_sizes = output_sizes.size(); + h_out = output_sizes[num_sizes - 2]; + w_out = output_sizes[num_sizes - 1]; + } + + if (h_out > 1 && w_out > 1) { + coord_trans_mode = "half_pixel"; + } + } + + if (coord_trans_mode == "half_pixel" || + coord_trans_mode == "align_corners" || + (using_sizes && coord_trans_mode == "asymmetric")) { + // supported + + // FWIW we could calculate (if shape inferencing didn't already) the output sizes and convert a node with + // `scales` and co-ord mode of `asymmetric` to having `sizes` input so it's supported. + } else { + LOGS(logger, VERBOSE) << "Resize with 'mode' of 'linear' requires 'coordinate_transformation_mode' of " + "'half_pixel', or 'align_corners', or 'asymmetric' with 'sizes' input. Got: " + << coord_trans_mode; + + return false; + } + } + } else { + // NeuralNetwork checks + if (!using_asymmetric) { + // align_corners and half_pixel could be supported in ResizeBilinear but as NeuralNetwork is deprecated + // there's no known value to adding that. + LOGS(logger, VERBOSE) << "Resize only supports 'asymmetric' coordinate_transformation_mode. Got: " + << coord_trans_mode; + return false; + } + } + return true; } diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index 8f85ab2c09e7c..385588dbfdcb8 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -141,8 +141,17 @@ class ModelBuilder { // so we don't do a copy of the original initializer into the model. void AddInitializerToSkip(const std::string& tensor_name); - // There are some input which will not be used, add it to a list which will not - // be added to CoreML model, since CoreML does not like input unused + /// + /// Skip a non-initializer value, that is not used in the CoreML model, but was an input to a supported node. + /// + /// This is for a rare edge case where a value is an input to a node but is empty/unused, as the + /// CoreML model requires all model inputs to be consumed. + /// + /// + /// The only known use case for this currently is Resize, and that is largely due to how the unit tests are + /// setup rather than something you'd expect to see in a real model. + /// See ResizeOpBuilder::AddInitializersToSkip for more details. 
+ /// void AddInputToSkip(const std::string& input_name); const std::string& GetUniqueName(const std::string& base_name); diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 0ba715cc7c6d9..a92fef81ac395 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -27,6 +27,7 @@ CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, coreml_flags_(coreml_flags), coreml_version_(coreml::util::CoreMLVersion()) { + LOGS_DEFAULT(VERBOSE) << "CoreML version: " << coreml_version_; if (coreml_version_ < MINIMUM_COREML_VERSION) { LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index f53779058a8af..9c8a8712ca51c 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -17,6 +17,10 @@ #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/tunable/cuda_tuning_context.h" +#ifndef DISABLE_CONTRIB_OPS +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#endif + namespace onnxruntime { void RunOnUnload(std::function function); @@ -80,6 +84,14 @@ class CUDAExecutionProvider : public IExecutionProvider { bool IsNHWCPreferred() const { return info_.prefer_nhwc; } bool UseTF32() const { return info_.use_tf32; } +#ifndef DISABLE_CONTRIB_OPS + // Attention kernel options parsed from sdpa_kernel cuda provider option. + const AttentionKernelOptions* GetAttentionKernelOptions() const { + attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true); + return &attention_kernel_options_; + } +#endif + ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); } @@ -110,6 +122,11 @@ class CUDAExecutionProvider : public IExecutionProvider { // the tuning context might be altered when calling into a TunableOp mutable cuda::tunable::CudaTuningContext tuning_context_; +#ifndef DISABLE_CONTRIB_OPS + // Attention kernel options parsed from sdpa_kernel cuda provider option. 
+ mutable AttentionKernelOptions attention_kernel_options_; +#endif + class PerThreadContext final { public: PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index c96381e3e68b1..31cf991a34fc9 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s constexpr const char* kPreferNHWCMode = "prefer_nhwc"; constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; constexpr const char* kUseTF32 = "use_tf32"; +constexpr const char* kSdpaKernel = "sdpa_kernel"; } // namespace provider_option_names } // namespace cuda @@ -117,6 +118,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) + .AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -170,6 +172,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; return options; @@ -192,6 +195,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, + {cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 1cac3d1513698..0efad80f743df 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -79,6 +79,8 @@ struct CUDAExecutionProviderInfo { // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. 
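Usage note (not part of this patch): the new `sdpa_kernel` value introduced in these hunks is parsed through the regular CUDA provider-option string path, so an application can set it like any other option. Below is a minimal, hedged sketch using the public ORT C API; error handling is omitted and the value "1" is only a placeholder, since the meaning of the integer is defined by AttentionKernelOptions rather than shown in this diff.

#include <onnxruntime_c_api.h>

// Sketch: append the CUDA EP with the new "sdpa_kernel" provider option set.
// All OrtStatus* return values are ignored here for brevity.
void AppendCudaEpWithSdpaKernel(const OrtApi* api, OrtSessionOptions* session_options) {
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  api->CreateCUDAProviderOptions(&cuda_options);

  const char* keys[] = {"sdpa_kernel"};
  const char* values[] = {"1"};  // placeholder; parsed into CUDAExecutionProviderInfo::sdpa_kernel
  api->UpdateCUDAProviderOptions(cuda_options, keys, values, 1);

  api->SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options);
  api->ReleaseCUDAProviderOptions(cuda_options);
}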
bool use_tf32{true}; + int sdpa_kernel{0}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); @@ -91,6 +93,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { size_t value{0xbc9f1d34}; // seed // Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each) + // Do not exceed 32 bits here otherwise some bits will be lost in x86. size_t data = static_cast(info.device_id) ^ (static_cast(info.arena_extend_strategy) << 16) ^ (static_cast(info.cudnn_conv_algo_search) << 18) ^ @@ -109,6 +112,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { onnxruntime::HashCombine(info.gpu_mem_limit, value); onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + onnxruntime::HashCombine(info.sdpa_kernel, value); // Memory pointers onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 288da23f35ec8..9d37a9775872f 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -94,6 +94,12 @@ class CudaKernel : public OpKernel { return provider_->UseTF32(); } +#ifndef DISABLE_CONTRIB_OPS + const AttentionKernelOptions* GetAttentionKernelOptions() const { + return provider_->GetAttentionKernelOptions(); + } +#endif + tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 7851da7fa91a3..b1d54e56ded4e 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -226,6 +226,7 @@ struct CUDA_Provider : Provider { info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; info.use_tf32 = params->use_tf32 != 0; + info.sdpa_kernel = params->sdpa_kernel; return std::make_shared(info); } @@ -260,6 +261,7 @@ struct CUDA_Provider : Provider { cuda_options.prefer_nhwc = internal_options.prefer_nhwc; cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; cuda_options.use_tf32 = internal_options.use_tf32; + cuda_options.sdpa_kernel = internal_options.sdpa_kernel; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index d75b9cc72ff4b..ef27f6c942f44 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -9,6 +9,7 @@ #include "core/graph/graph_viewer.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" +#include "core/providers/utils.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h" @@ -251,14 +252,34 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, 
const N const Initializer unpacked_tensor(*scales); auto scales_data = unpacked_tensor.DataAsSpan(); input_is_nchw = scales_data[1] == 1.0F; - float const scale_n = scales_data[0]; - float const scale_c = input_is_nchw ? scales_data[1] : scales_data[3]; + const float scale_n = scales_data[0]; + const float scale_c = input_is_nchw ? scales_data[1] : scales_data[3]; + const float scale_h = input_is_nchw ? scales_data[2] : scales_data[1]; + const float scale_w = input_is_nchw ? scales_data[3] : scales_data[2]; + if (scale_n != 1.0f || scale_c != 1.0f) { LOGS_DEFAULT(VERBOSE) << "Scales of N/C channel should be 1" << "Resize of N/C channels are not supported" << ", scale_n, " << scale_n << ", scale_c, " << scale_c; return false; } + + // if downsampling the input size must be evenly divisible by the output size to match the onnx output + if (scale_h < 1.0f || scale_w < 1.0f) { + // we also require input_shape to be known to check + auto h_in = input_is_nchw ? input_shape[2] : input_shape[1]; + auto w_in = input_is_nchw ? input_shape[3] : input_shape[2]; + if (h_in == 0 || w_in == 0) { + LOGS_DEFAULT(VERBOSE) << "Input H and W must be known to downsample with scales"; + return false; + } + + if (!utils::IsScalingByAFactorOfN(h_in, scale_h) || + !utils::IsScalingByAFactorOfN(w_in, scale_w)) { + LOGS_DEFAULT(VERBOSE) << "Input size must be evenly divisible by output size when downsampling"; + return false; + } + } } else { const auto* sizes = graph_viewer.GetConstantInitializer(inputs[3].node_arg.Name()); if (!sizes) { diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 42788f2960197..ef45d6c85d6a9 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -274,6 +274,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph auto& attrs = node->GetAttributes(); const int64_t embed_mode = attrs.at(EMBED_MODE).i(); + // Only make path checks if model not provided as byte buffer + bool make_secure_path_checks = !GetModelPath(graph_viewer).empty(); + if (embed_mode) { // Get engine from byte stream. const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s(); @@ -284,6 +287,23 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); } + + if (weight_stripped_engine_refit_) { + const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); + std::string placeholder; + auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + placeholder, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + (*trt_engine_).get(), + false /* serialize refitted engine to disk */, + detailed_build_log_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + } } else { // Get engine from cache file. 
std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s(); @@ -343,7 +363,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, weight_stripped_engine_cache, - true /* path check for security */, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, (*trt_engine_).get(), true /* serialize refitted engine to disk */, detailed_build_log_); diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index 3be08d043da48..3af0143cbf14e 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -52,6 +52,8 @@ class TensorRTCacheModelHandler { std::string compute_capability, bool weight_stripped_engine_refit, std::string onnx_model_folder_path, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, bool detailed_build_log) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), @@ -59,6 +61,8 @@ class TensorRTCacheModelHandler { compute_capability_(compute_capability), weight_stripped_engine_refit_(weight_stripped_engine_refit), onnx_model_folder_path_(onnx_model_folder_path), + onnx_model_bytestream_(onnx_model_bytestream), + onnx_model_bytestream_size_(onnx_model_bytestream_size), detailed_build_log_(detailed_build_log) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); @@ -74,6 +78,8 @@ class TensorRTCacheModelHandler { std::string compute_capability_; bool weight_stripped_engine_refit_; std::string onnx_model_folder_path_; + const void* onnx_model_bytestream_; + size_t onnx_model_bytestream_size_; bool detailed_build_log_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 67cbc8f5d6f13..cdbb7bb2a8094 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1333,6 +1333,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv engine_cache_enable_ = info.engine_cache_enable; weight_stripped_engine_enable_ = info.weight_stripped_engine_enable; onnx_model_folder_path_ = info.onnx_model_folder_path; + onnx_model_bytestream_ = info.onnx_bytestream; + onnx_model_bytestream_size_ = info.onnx_bytestream_size; + if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) || + (onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "When providing either 'trt_onnx_bytestream_size' or " + "'trt_onnx_bytestream' both have to be provided")); + } timing_cache_enable_ = info.timing_cache_enable; force_timing_cache_match_ = info.force_timing_cache; detailed_build_log_ = info.detailed_build_log; @@ -1757,7 +1765,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_ep_context_file_path: " << ep_context_file_path_ << ", trt_ep_context_embed_mode: " << ep_context_embed_mode_ << ", trt_cache_prefix: " << cache_prefix_ - << ", trt_engine_hw_compatible: " << engine_hw_compatible_; + << ", trt_engine_hw_compatible: " << engine_hw_compatible_ + << ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_; } 
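Usage note: the new `trt_onnx_bytestream` / `trt_onnx_bytestream_size` pair lets the caller hand the weight-full ONNX model to the TensorRT EP as an in-memory buffer, so a weight-stripped engine can be refitted without reading the model from disk. A minimal sketch using `UpdateTensorRTProviderOptionsWithValue` (which this diff extends to accept these keys); error handling is omitted and the buffer must stay valid until the session has been created and refitted.

// model_bytes / model_size: the caller-owned, in-memory copy of the weight-full ONNX model.
void SetTrtOnnxBytestream(const OrtApi* api, OrtTensorRTProviderOptionsV2* trt_options,
                          void* model_bytes, size_t model_size) {
  // the key decides how the void* value is interpreted by the provider bridge
  api->UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream", model_bytes);
  api->UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream_size", &model_size);
}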
TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -2597,28 +2606,42 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, bool path_check, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log) { #if NV_TENSORRT_MAJOR >= 10 + bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; std::filesystem::path onnx_model_path{onnx_model_folder_path}; - onnx_model_path.append(onnx_model_filename); - if (path_check && IsAbsolutePath(onnx_model_path.string())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "For security purpose, the ONNX model path should be set with " - "a relative path, but it is an absolute path: " + - onnx_model_path.string()); - } - if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "The ONNX model path has '..'. For security purpose, it's not " - "allowed to point outside the directory."); - } + if (refit_from_file) { + if (!onnx_model_filename.empty()) { + onnx_model_path.append(onnx_model_filename); + } + if (onnx_model_path.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The ONNX model was not provided as path. " + "Please use provide an ONNX bytestream to enable refitting the weightless engine."); + } else { + // check if file path to ONNX is legal + if (path_check && IsAbsolutePath(onnx_model_path.string())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "For security purpose, the ONNX model path should be set with " + "a relative path, but it is an absolute path: " + + onnx_model_path.string()); + } + if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The ONNX model path has '..'. 
For security purpose, it's not " + "allowed to point outside the directory."); + } - if (!std::filesystem::exists(onnx_model_path)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "The ONNX model " + onnx_model_path.string() + - " does not exist."); + if (!(std::filesystem::exists(onnx_model_path) && std::filesystem::is_regular_file(onnx_model_path))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The ONNX model " + onnx_model_path.string() + + " does not exist."); + } + } } // weight-stripped engine refit logic @@ -2626,9 +2649,18 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); auto parser_refitter = std::unique_ptr( nvonnxparser::createParserRefitter(*refitter, trt_logger)); - if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + if (refit_from_file) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string(); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + } + } else { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array"; + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem"); + } } if (refitter->refitCudaEngine()) { LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine."; @@ -3212,10 +3244,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } if (weight_stripped_engine_refit_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build"; + char* onnx = string_buf.data(); + size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, + onnx, + onnx_size, trt_engine.get(), true /* serialize refitted engine to disk */, detailed_build_log_); @@ -3685,6 +3722,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView onnx_model_folder_path_, engine_cache_path, false /* path check for security */, + onnx_model_bytestream_, + onnx_model_bytestream_size_, trt_engine, true /* serialize refitted engine to disk */, detailed_build_log_); @@ -3910,6 +3949,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con compute_capability_, weight_stripped_engine_enable_, onnx_model_folder_path_, + onnx_model_bytestream_, + onnx_model_bytestream_size_, detailed_build_log_); auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); if (status != Status::OK()) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index b58e86237860c..3f20314438564 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -274,13 
+274,12 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; - /** - * Refit the weight-stripped engine - */ static common::Status RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, std::string& weight_stripped_engine_cath_path, bool path_check, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log); @@ -305,6 +304,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool weight_stripped_engine_enable_ = false; bool weight_stripped_engine_refit_ = false; std::string onnx_model_folder_path_; + const void* onnx_model_bytestream_; + size_t onnx_model_bytestream_size_; bool build_heuristics_enable_ = false; bool sparsity_enable_ = false; int builder_optimization_level_ = 3; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index 9fe39f5921e1c..63b6d35072290 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -54,6 +54,8 @@ constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode"; constexpr const char* kEpContextFilePath = "trt_ep_context_file_path"; constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; +constexpr const char* kONNXBytestream = "trt_onnx_bytestream"; +constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size"; } // namespace provider_option_names } // namespace tensorrt @@ -61,6 +63,7 @@ constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { TensorrtExecutionProviderInfo info{}; void* user_compute_stream = nullptr; + void* onnx_bytestream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -122,10 +125,20 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path) .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode) .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible) + .AddValueParser( + tensorrt::provider_option_names::kONNXBytestream, + [&onnx_bytestream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + onnx_bytestream = reinterpret_cast(address); + return Status::OK(); + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) .Parse(options)); // add new provider option here. 
info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); + info.onnx_bytestream = onnx_bytestream; return info; } @@ -173,6 +186,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, + {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)}, + {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)}, }; return options; } @@ -234,6 +249,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)}, {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)}, {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)}, + {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast(info.trt_onnx_bytestream))}, + {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)}, }; return options; } @@ -336,5 +353,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode; trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path); trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible; + trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream; + trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 3b859ea2da466..50b934fd5fcbc 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -34,6 +34,8 @@ struct TensorrtExecutionProviderInfo { std::string engine_cache_path{""}; bool weight_stripped_engine_enable{false}; std::string onnx_model_folder_path{""}; + const void* onnx_bytestream{nullptr}; + size_t onnx_bytestream_size{0}; bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 6430ffab09976..e242788ff389a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -116,6 +116,8 @@ struct Tensorrt_Provider : Provider { info.ep_context_embed_mode = options.trt_ep_context_embed_mode; info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? 
"" : options.trt_engine_cache_prefix; info.engine_hw_compatible = options.trt_engine_hw_compatible != 0; + info.onnx_bytestream = options.trt_onnx_bytestream; + info.onnx_bytestream_size = options.trt_onnx_bytestream_size; return std::make_shared(info); } diff --git a/onnxruntime/core/providers/utils.cc b/onnxruntime/core/providers/utils.cc index b2f9d265ca053..747b09e42aa21 100644 --- a/onnxruntime/core/providers/utils.cc +++ b/onnxruntime/core/providers/utils.cc @@ -23,5 +23,21 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& return Status::OK(); } #endif + +bool IsScalingByAFactorOfN(int64_t n, float scale) { + bool is_factor = false; + if (scale > 0.f && scale < 1.f) { + const double factor = 1.0 / scale; + const double factor_rounded = std::round(factor); + constexpr double epsilon = 1.0e-4; // arbitrarily small enough + if (std::abs(factor - factor_rounded) < epsilon) { + // result is integer. check if a factor of n + const int64_t factor_i = static_cast(factor_rounded); + is_factor = n % factor_i == 0; + } + } + + return is_factor; +} } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/utils.h b/onnxruntime/core/providers/utils.h index 8cafdb8c05cc3..9ea8496a02f85 100644 --- a/onnxruntime/core/providers/utils.h +++ b/onnxruntime/core/providers/utils.h @@ -15,5 +15,10 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& OpKernelContext* context, int output_index); #endif +/// +/// Check if the reciprocal of 'scale' is a factor of 'n'. +/// e.g. a scale of 0.5 is 1/2, the reciprocal is 2, and 2 is a factor of any even number. +/// +bool IsScalingByAFactorOfN(int64_t n, float scale); } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc index ab31aa313cf6d..368c8c0358228 100644 --- a/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc +++ b/onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc @@ -466,7 +466,7 @@ std::string RetrieveEPContextCache( fs::path ep_ctx_fs_path(ep_ctx_model_loc); // Attr "ep_cache_context" stores a relative path. ep_ctx_fs_path.replace_filename(fs::path(ep_ctx_cache)); - // TODO: Validaion of the file location to make sure security is met. + // TODO: Validation of the file location to make sure security is met. 
if (!fs::exists(ep_ctx_fs_path) || !fs::is_regular_file(ep_ctx_fs_path)) { ORT_THROW("File for EP context cache is missing"); } diff --git a/onnxruntime/core/providers/vitisai/include/ep_context_utils.h b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h index 61a595cf1ae15..26546f422765c 100644 --- a/onnxruntime/core/providers/vitisai/include/ep_context_utils.h +++ b/onnxruntime/core/providers/vitisai/include/ep_context_utils.h @@ -14,8 +14,8 @@ namespace fs = std::filesystem; namespace onnxruntime { constexpr const uint8_t kXCCode = 1; -constexpr const uint8_t kDDCode = 2; -constexpr const uint8_t kVCode = 4; +[[maybe_unused]] constexpr const uint8_t kDDCode = 2; +[[maybe_unused]] constexpr const uint8_t kVCode = 4; static constexpr const char* kEPContextOp = "EPContext"; static constexpr const char* kMainContextAttr = "main_context"; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index f45b89649bfcb..036831df7a9cf 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -86,7 +86,7 @@ void VitisAIExecutionProvider::PrepareEPContextEnablement( model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string()); } std::string backend_cache_dir, backend_cache_key; - get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode, backend_cache_dir, backend_cache_key, backend_cache_data_); + get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode | kDDCode | kVCode, backend_cache_dir, backend_cache_key, backend_cache_data_); info_["cacheDir"] = backend_cache_dir; info_["cacheKey"] = backend_cache_key; // Create a new model, reusing the graph name, the op-domain-to-opset-version map, diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 09666c8039402..c752b5f849808 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -11,6 +11,7 @@ #include "core/framework/op_kernel.h" #include "core/optimizer/initializer.h" #include "core/providers/xnnpack/xnnpack_init.h" +#include "core/providers/utils.h" namespace onnxruntime { namespace xnnpack { @@ -68,9 +69,27 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, InlinedVector scale(4, 1.0F); if (scale_tensor) { const Initializer scale_val(*scale_tensor, node_unit.ModelPath()); - if (scale_val.DataAsSpan()[1] != 1.0F) { + const auto scales = scale_val.DataAsSpan(); + if (scales[1] != 1.0F) { break; } + + // downsampling output seems to require the output size to be a factor of the input to match ONNX + if (scales[2] < 1.0f || scales[3] < 1.0f) { + // we also require input_shape to be known to check + int64_t h_in = x_shape->dim(2).dim_value(); + int64_t w_in = x_shape->dim(3).dim_value(); + if (h_in < 0 || w_in < 0) { + break; + } + + float scale_h = scales[2]; + float scale_w = scales[3]; + if (!utils::IsScalingByAFactorOfN(h_in, scale_h) || + !utils::IsScalingByAFactorOfN(w_in, scale_w)) { + break; + } + } } if (size_tensor) { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f0eed91d70440..3fd6e84e0e5ce 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1609,7 +1609,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, #if 
!defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, - const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) { + const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, + concurrency::ThreadPool* intra_op_thread_pool) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1617,7 +1618,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable); + optimizers_to_disable, intra_op_thread_pool); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2005,7 +2006,8 @@ common::Status InferenceSession::Initialize() { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( - ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep)); + ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, + cpu_ep, GetIntraOpThreadPoolToUse())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3167,7 +3169,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3176,7 +3179,8 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } }(); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 53e32a6221ec4..924158a26b927 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2484,6 +2484,10 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptionsWithValue, if (strcmp(key, "user_compute_stream") == 0) { tensorrt_options->has_user_compute_stream = 1; tensorrt_options->user_compute_stream = value; + } else if (strcmp(key, "trt_onnx_bytestream") == 0) { + tensorrt_options->trt_onnx_bytestream = value; + } else if (strcmp(key, "trt_onnx_bytestream_size") == 0) { + tensorrt_options->trt_onnx_bytestream_size = *reinterpret_cast(value); } return nullptr; #else diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 9ab4a82463d51..9bc50ce88ef16 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -12,6 +12,7 @@ #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" +#include "core/framework/int4.h" #include "test/util/include/test_random_seed.h" namespace onnxruntime { @@ -108,6 +109,22 @@ class 
RandomValueGenerator { return val; } + template + typename std::enable_if< + std::is_same_v || std::is_same_v, + std::vector>::type + Uniform(gsl::span dims, TInt4 min, TInt4 max) { + using UnpackedType = typename TInt4::UnpackedType; + std::vector data_int8 = Uniform(dims, min.GetElem(0), max.GetElem(0)); + std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); + for (size_t i = 0; i < data_int8.size(); i++) { + size_t r = i >> 1; + size_t c = i & 0x1; + data[r].SetElem(c, data_int8[i]); + } + return data; + } + // Gaussian distribution for float template typename std::enable_if< diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index a61e917b41e51..f0255d7ece84e 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -394,8 +394,8 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu } #if USE_MEMORY_EFFICIENT_ATTENTION - if (data.sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32 || - data.kv_sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32) { + if (data.sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32 || + data.kv_sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32) { kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 6214094a26c4f..b9af675afe74d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -117,22 +117,6 @@ class ModelTestBuilder { return MakeInput(shape, data); } - template - typename std::enable_if< - std::is_same_v || std::is_same_v, - NodeArg*>::type - MakeInputInt4(const std::vector& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) { - using UnpackedType = typename TInt4::UnpackedType; - std::vector data_int8 = rand_gen_.Uniform(shape, min, max); - std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); - for (size_t i = 0; i < data_int8.size(); i++) { - size_t r = i >> 1; - size_t c = i & 0x1; - data[r].SetElem(c, data_int8[i]); - } - return MakeInput(shape, data); - } - template NodeArg* MakeInput(const std::optional>& shape, std::optional input_name = std::nullopt) { diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc new file mode 100644 index 0000000000000..3d117794104fa --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -0,0 +1,425 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
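// Descriptive note (editorial, not part of the patch): this new test file exercises the
// DQ + MatMul -> MatMulNBits selector/action transformer. The cases below cover graphs that
// must NOT be converted (non-constant DQ input, DQ feeding the first MatMul input, and
// type/shape mismatches such as unsupported block size, axis, or rank) as well as graphs
// that should be rewritten to com.microsoft.MatMulNBits, each across several accuracy levels.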
+ +#include + +#include "core/common/span_utils.h" +#include "core/framework/int4.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4127) +#endif // #if defined(_MSC_VER) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +// Input1 Input2 +// | | +// \ DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* input2_arg = builder.MakeInput(input2_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{input2_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input1_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + 
+TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); +} + +// Input2 +// | +// DQ / +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* input2_arg = builder.MakeInput(input2_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {dq_output, input2_arg}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, 
DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); +} + +// Input1 +// | +// \ DQ +// \ / +// MatMul +// | +// output +template +void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + NodeArg* weight_arg = nullptr; + + // add DQ + if constexpr (std::is_same_v || std::is_same_v) { + weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + } else { + weight_arg = builder.MakeInitializer(weight_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + NodeArg* zp_arg; + if constexpr (std::is_same_v || std::is_same_v) { + zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + } else { + zp_arg = builder.MakeInitializer(scale_shape, 0, 2); + } + + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + 
std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization + // block size too small + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + // block size not 2's power + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + // not axis 0 + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + // not rank 2 + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ +// \ / +// MatMul +// | DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulConverted(const std::vector& input1_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / 
block_size; + + auto* weight1_arg = builder.MakeInitializer(weight1_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq1_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 8.0f, 12.0f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp1_arg = builder.MakeInitializer(scale1_shape, T(0, 0), T(2, 0)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 862408f31f004..52ac2a2541a79 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -517,7 +517,7 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, NodeArg* input_arg = nullptr; if constexpr (std::is_same_v || std::is_same_v) { - input_arg = builder.MakeInputInt4(input_shape, InputType::min_val, InputType::max_val); + input_arg = builder.MakeInput(input_shape, InputType(InputType::min_val, 0), InputType(InputType::max_val, 0)); dq_zp = 
InputType(static_cast(InputType::max_val / 2)); q_zp = OutputType(static_cast(OutputType::max_val / 2)); } else { diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1c77121ba9df1..1638851daf65a 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2763,6 +2763,57 @@ TEST(QDQTransformerTests, Clip) { } } +// Test that the ReluQuantFusion transformer only runs for optimization level >= 2. +TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) { + auto test_case = [&](TransformerLevel opt_level, int8_t zp) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 2, 2, 2}, + {-4, -3, -2, 0, 1, 2, 3, 4}); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, 1.0f, zp, dq_output); + + // add Relu + auto* relu_output = builder.MakeIntermediate(); + builder.AddNode("Relu", {dq_output}, {relu_output}); + + // add Q + DQ + auto* q_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(relu_output, 1.0f, zp, q_output); + builder.AddDequantizeLinearNode(q_output, 1.0f, zp, output_arg); + }; + + auto check_relu_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + // Only fuse relu into Q if level >= 2 and zero_point == -128 for int8. + // Level1 graph: input -> DQ -> Relu -> Q -> DQ -> output + // Level2+ graph: input -> DQ -> output (QuantReluFusion + QDQFinalCleanupTransformer transformers) + const bool fuse_relu = (zp == -128) && + (opt_level == TransformerLevel::Level2 || opt_level == TransformerLevel::Level3); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], fuse_relu ? 0 : 1); + EXPECT_EQ(op_to_count["Relu"], fuse_relu ? 0 : 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], fuse_relu ? 1 : 2); + }; + + constexpr float epsilon = std::numeric_limits::epsilon(); + + TransformerTester(build_test_case, check_relu_graph, + TransformerLevel::Default, + opt_level, + 18, + epsilon, + epsilon); + }; + + test_case(TransformerLevel::Level1, -128); // Will not fuse Relu into QuantizeLinear due to level1 opt. + test_case(TransformerLevel::Level2, -128); // Will fuse Relu into QuantizeLinear. + test_case(TransformerLevel::Level3, -128); // Will fuse Relu into QuantizeLinear. 
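Review note on the zero-point condition these test_case calls exercise: ReluQuantFusion can only drop the Relu when the quantization parameters already guarantee non-negative values. A minimal sketch of that reasoning (illustrative only; the helper name is hypothetical and this is not the transformer's actual code):

#include <cstdint>

// For DQ -> Relu -> Q over int8 data, dequantized_value = scale * (q - zero_point).
// With q in [-128, 127] and zero_point == -128, every dequantized value is >= 0
// (given a positive scale), so Relu is an identity and the DQ -> Relu -> Q chain
// can be folded away, which is what the Level2+ checks above expect.
bool ReluIsNoOpForInt8(float scale, int8_t zero_point) {
  return scale > 0.0f && zero_point == -128;
}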
+ test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128 +} + TEST(QDQTransformerTests, Concat) { auto test_case = [&](const std::vector>& input_shapes, int64_t axis, diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index ab436d4f83e86..84c3bc16346f3 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -143,6 +143,7 @@ namespace perftest { "\t-D [Disable thread spinning]: disable spinning entirely for thread owned by onnxruntime intra-op thread pool.\n" "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n" "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" + "\t-l Provide file as binary in memory by using fopen before session creation.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -205,7 +206,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzn"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznl"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -389,6 +390,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'n': test_config.run_config.exit_after_session_creation = true; break; + case 'l': + test_config.model_info.load_via_path = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index ff782da35cbe6..92d732fba2a0a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -5,6 +5,7 @@ #include "ort_test_session.h" #include #include +#include #include #include #include @@ -816,8 +817,21 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
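Context for the new -l option above: onnxruntime_perf_test now reads the model into a byte buffer and builds the session from memory, which is what the ort_test_session.cc hunk below implements. A minimal standalone sketch of the same pattern against the public C++ API (the function name, placeholder path handling, and error text are illustrative, not taken from this change):

#include <filesystem>
#include <fstream>
#include <ios>
#include <stdexcept>
#include <vector>

#include <onnxruntime_cxx_api.h>

// Read the whole model file into memory, then create the session from the bytes
// instead of from the path.
Ort::Session CreateSessionFromBytes(Ort::Env& env, const std::filesystem::path& model_path,
                                    const Ort::SessionOptions& session_options) {
  std::ifstream file(model_path, std::ios::binary | std::ios::in | std::ios::ate);
  if (!file.is_open()) {
    throw std::runtime_error("Model file could not be opened.");
  }
  const std::streamsize fsize = file.tellg();
  file.seekg(0, std::ios_base::beg);
  std::vector<char> model_bytes(static_cast<size_t>(fsize));
  file.read(model_bytes.data(), fsize);
  return Ort::Session(env, model_bytes.data(), model_bytes.size(), session_options);
}

With the change above, passing -l to onnxruntime_perf_test exercises exactly this in-memory path instead of handing the file path to the session constructor.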
\n)"); #endif } - session_ = Ort::Session(env, performance_test_config.model_info.model_file_path.c_str(), session_options); - + if (!performance_test_config.model_info.load_via_path) { + session_ = Ort::Session(env, performance_test_config.model_info.model_file_path.c_str(), session_options); + } else { + std::ifstream file(performance_test_config.model_info.model_file_path.c_str(), + std::ios::binary | std::ios::in | std::ios::ate); + if (file.is_open()) { + const std::streamsize fsize = file.tellg(); + file.seekg(0, std::ios_base::beg); + std::vector model_bytes(narrow(fsize)); + file.read(model_bytes.data(), fsize); + session_ = Ort::Session(env, model_bytes.data(), model_bytes.size(), session_options); + } else { + ORT_THROW("Model file could not be opened.\n"); + } + } size_t output_count = session_.GetOutputCount(); output_names_.resize(output_count); Ort::AllocatorWithDefaultOptions a; diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 70a6b12690d5d..209fb55fe93d4 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -29,6 +29,7 @@ struct ModelInfo { std::basic_string model_file_path; std::basic_string input_file_path; std::basic_string result_file_path; + bool load_via_path = false; }; struct MachineConfig { diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 496f2213e9d32..111520ef03e26 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -227,28 +227,33 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e } TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { - OpTester test("Resize", 13); - std::vector roi{}; - std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; + auto run_test = [](bool scales_in_initializer) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; - test.AddAttribute("mode", "linear"); + test.AddAttribute("mode", "linear"); - constexpr int64_t N = 1, C = 1, H = 2, W = 4; - std::vector X = { - 1.0f, 2.0f, 3.0f, 4.0f, - 5.0f, 6.0f, 7.0f, 8.0f}; + constexpr int64_t N = 1, C = 1, H = 2, W = 4; + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; - test.AddInput("X", {N, C, H, W}, X); - test.AddInput("roi", {0}, roi); - test.AddInput("scales", {4}, scales); + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {4}, scales, scales_in_initializer); - std::vector Y = {2.66666651f, 4.3333331f}; + std::vector Y = {2.66666651f, 4.3333331f}; - test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - // QNN: result diff - // TRT: Segmentation fault in A100 - std::unordered_set excluded_providers({kQnnExecutionProvider}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers)); + test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); + // QNN: result diff + // TRT: Segmentation fault in A100 + std::unordered_set excluded_providers({kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers)); + }; + + run_test(false); + run_test(true); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) { @@ -327,13 +332,14 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { // 
Since NNAPI(TFLite) only using the scale calculate using the input/output size // For the above test (ResizeOpLinearDownSampleTest_4DBilinear) // The output size is [1,1,2,4].*[1,1,0.6,0.6]=[1,1,1,2] -// NNAPI will recaluclate the scales as the output size divided by input size +// NNAPI will recalculate the scales as the output size divided by input size // scales = [1,1,1,2]./[1,1,2,4] = [1,1,0.5,0.5] // See:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h // So the result of the above example will be different than CPU EP -// Add the following 2 tests to test with scales valid to NNAPI +// Add the following 2 tests to test with scales valid to NNAPI. +// CoreML also doesn't handle a scale that doesn't divide the input size evenly. TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -360,8 +366,38 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) { run_test(true); } +// Downsize with factor being an odd number (1/3) +TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_OddNumber) { + // To test NNAPI EP, we need the scales/sizes to be in initializers + auto run_test = [](bool scales_in_initializer) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 1.0f, (1.f / 3), (1.f / 3)}; + + test.AddAttribute("mode", "linear"); + + constexpr int64_t N = 1, C = 1, H = 3, W = 6; + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f}; + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {4}, scales, scales_in_initializer); + + std::vector Y = {8.f, 11.f}; + + test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); + }; + + run_test(false); + run_test(true); +} + TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_WithSizes) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_and_sizes_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -389,8 +425,32 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_WithSizes) { run_test(true); } +// test handling for opset 11. 
scales input is provided but should be ignored in favor of sizes +TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1_WithSizesOpset11) { + OpTester test("Resize", 11); + std::vector roi{}; + std::vector scales{}; + constexpr int64_t N = 1, C = 1, H = 2, W = 4; + std::vector sizes{N, C, 1, 2}; + test.AddAttribute("mode", "linear"); + + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {0}, scales); + test.AddInput("sizes", {4}, sizes, true); // add as initializer so CoreML EP can take + + std::vector Y = {3.5f, 5.5f}; + + test.AddOutput("Y", sizes, Y); + test.Run(); +} + TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -416,15 +476,51 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) { run_test(false); -#ifdef USE_NNAPI - // NNAPI will need the scales as an initializer +#if defined(USE_NNAPI) || defined(USE_COREML) + // NNAPI and CoreML need the scales as an initializer + // Also tensor RT EP will fail if scales is an initializer but will pass if it is not + run_test(true); +#endif +} + +TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners_sizes) { + // To test NNAPI EP, we need the scales/sizes to be in initializers + auto run_test = [](bool scales_in_initializer) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{}; + std::vector sizes{1, 1, 1, 2}; + + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", "align_corners"); + + constexpr int64_t N = 1, C = 1, H = 2, W = 4; + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("", {0}, scales); + test.AddInput("sizes", {4}, sizes, scales_in_initializer); + + std::vector Y = {1.0f, 4.0f}; + + test.AddOutput("Y", {N, C, 1, 2}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); + }; + + run_test(false); + +#if defined(USE_NNAPI) || defined(USE_COREML) + // NNAPI and CoreML will need the scales as an initializer // Also tensor RT EP will fail if scales is an initializer but will pass if it is not run_test(true); #endif } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uint8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -456,7 +552,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -622,7 +718,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric_scales) { } TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To 
test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -668,7 +764,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { } TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_int8) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; diff --git a/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc new file mode 100644 index 0000000000000..b2e986f680763 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef DISABLE_CONTRIB_OPS + +#include "contrib_ops/cuda/bert/attention_kernel_options.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "test/util/include/scoped_env_vars.h" +#include "gtest/gtest.h" + +#include +#include + +using onnxruntime::AttentionKernelOptions; +using onnxruntime::contrib::attention::AttentionBackend; + +namespace onnxruntime { +namespace test { + +TEST(AttentionKernelOptionsTest, NonZeroValue) { + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION) | static_cast(AttentionBackend::EFFICIENT_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_TRUE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_FUSED_ATTENTION) | static_cast(AttentionBackend::MATH); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_TRUE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_TRUE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions 
options; + int value = static_cast(AttentionBackend::TRT_FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_TRUE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + { + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::TRT_CROSS_ATTENTION) | static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_TRUE(options.UseTrtCrossAttention()); + ASSERT_TRUE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + // Test environment variables are ignored when option value is non-zero + // Test default min sequence lengths are zeros + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}}}; + AttentionKernelOptions options; + int value = static_cast(AttentionBackend::FLASH_ATTENTION); + options.InitializeOnce(value, false); + ASSERT_TRUE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_FALSE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0); + } + + // Test min sequence lengths can be parsed from environment variables when option value is non-zero + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"}, + 
{onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}};
+    AttentionKernelOptions options;
+    int value = static_cast<int>(AttentionBackend::FLASH_ATTENTION);
+    options.InitializeOnce(value, false);
+    ASSERT_TRUE(options.UseFlashAttention());
+    ASSERT_FALSE(options.UseEfficientAttention());
+    ASSERT_FALSE(options.UseTrtFusedAttention());
+    ASSERT_FALSE(options.UseCudnnFlashAttention());
+    ASSERT_FALSE(options.UseUnfusedAttention());
+    ASSERT_FALSE(options.UseTrtFlashAttention());
+    ASSERT_FALSE(options.UseTrtCrossAttention());
+    ASSERT_FALSE(options.UseTrtCausalAttention());
+    EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128);
+    EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256);
+  }
+}
+
+// Test all environment variables take effect when option value is 0.
+TEST(AttentionKernelOptionsTest, DefaultOptionWithEnvVar) {
+  constexpr int value = 0;
+  ScopedEnvironmentVariables scoped_env_vars{
+      EnvVarMap{
+          {onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
+          {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
+          {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
+          {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"},
+          {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
+          {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"},
+          {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
+          {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"},
+          {onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}};
+  AttentionKernelOptions options;
+  options.InitializeOnce(value, false);
+  ASSERT_TRUE(options.UseFlashAttention());
+  ASSERT_TRUE(options.UseEfficientAttention());
+  ASSERT_TRUE(options.UseTrtFusedAttention());
+  ASSERT_TRUE(options.UseCudnnFlashAttention());
+  ASSERT_TRUE(options.UseUnfusedAttention());
+  ASSERT_TRUE(options.UseTrtFlashAttention());
+  ASSERT_TRUE(options.UseTrtCrossAttention());
+  ASSERT_TRUE(options.UseTrtCausalAttention());
+  EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128);
+  EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256);
+}
+
+// Test default min sequence lengths when environment variables are not set.
+TEST(AttentionKernelOptionsTest, DefaultMinSeqLens) { + constexpr int value = 0; + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}}}; + AttentionKernelOptions options; + options.InitializeOnce(value, false); + ASSERT_FALSE(options.UseFlashAttention()); + ASSERT_FALSE(options.UseEfficientAttention()); + ASSERT_FALSE(options.UseTrtFusedAttention()); + ASSERT_FALSE(options.UseCudnnFlashAttention()); + ASSERT_TRUE(options.UseUnfusedAttention()); + ASSERT_FALSE(options.UseTrtFlashAttention()); + ASSERT_FALSE(options.UseTrtCrossAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + ASSERT_FALSE(options.UseTrtCausalAttention()); + EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), + onnxruntime::contrib::attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); + EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), + onnxruntime::contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32); +} + +} // namespace test +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 2b5b82d0fc16a..63327a028c6f4 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -122,6 +122,18 @@ void CreateBaseModel(const PathString& model_name, status = onnxruntime::Model::Save(model, model_name); } +std::vector ReadFileFromDisk(const PathString& path) { + std::fstream file(path.c_str(), std::fstream::binary | std::fstream::in | std::fstream::ate); + std::vector file_bytes; + if (file.is_open()) { + auto fsize = file.tellg(); + file.seekg(0, std::ios_base::beg); + file_bytes.resize(fsize); + file.read(file_bytes.data(), fsize); + } + return file_bytes; +} + bool HasCacheFileWithPrefix(const std::string& prefix, std::string file_dir = "") { std::filesystem::path target_dir; if (file_dir.empty()) { @@ -360,7 +372,8 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { } TEST(TensorrtExecutionProviderTest, EPContextNode) { - PathString model_name = ORT_TSTR("EPContextNode_test.onnx"); + std::string model_name_str = "EPContextNode_test.onnx"; + PathString model_name = ToPathString(model_name_str); std::string graph_name = "EPContextNode_test"; std::string sess_log_id = "EPContextNode_test"; std::vector dims = {1, 3, 2}; @@ -461,11 +474,11 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { */ InferenceSession session_object3{so, GetEnvironment()}; OrtTensorRTProviderOptionsV2 params3; - model_name = ToPathString(params.trt_ep_context_file_path); + PathString ctx_model_name = ToPathString(params.trt_ep_context_file_path); params3.trt_engine_cache_enable = 1; execution_provider = TensorrtExecutionProviderWithOptions(¶ms3); EXPECT_TRUE(session_object3.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - status = session_object3.Load(model_name); + status = session_object3.Load(ctx_model_name); 
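A possible hardening of the ReadFileFromDisk helper introduced above (a suggestion only, not part of this change): when the path is wrong, the helper returns an empty vector and the failure only surfaces later as an opaque session Load() error. Failing at the read site keeps the test output pointed at the real problem. Sketch, assuming the same headers and PathString alias already used in tensorrt_basic_test.cc:

// Variant of ReadFileFromDisk that reports a bad path or short read immediately.
std::vector<char> ReadFileFromDiskOrFail(const PathString& path) {
  std::ifstream file(path.c_str(), std::ios::binary | std::ios::in | std::ios::ate);
  EXPECT_TRUE(file.is_open()) << "Could not open file for reading";
  std::vector<char> file_bytes;
  if (file.is_open()) {
    const std::streamsize fsize = file.tellg();
    file.seekg(0, std::ios_base::beg);
    file_bytes.resize(static_cast<size_t>(fsize));
    file.read(file_bytes.data(), fsize);
    EXPECT_TRUE(file.good()) << "Short read while loading file";
  }
  return file_bytes;
}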
ASSERT_TRUE(status.IsOK()); status = session_object3.Initialize(); ASSERT_TRUE(status.IsOK()); @@ -490,10 +503,10 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { */ InferenceSession session_object4{so, GetEnvironment()}; OrtTensorRTProviderOptionsV2 params4; - model_name = ORT_TSTR("./context_model_folder/EPContextNode_test_ctx.onnx"); + ctx_model_name = ToPathString("./context_model_folder/EPContextNode_test_ctx.onnx"); execution_provider = TensorrtExecutionProviderWithOptions(¶ms4); EXPECT_TRUE(session_object4.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - status = session_object4.Load(model_name); + status = session_object4.Load(ctx_model_name); ASSERT_TRUE(status.IsOK()); status = session_object4.Initialize(); ASSERT_TRUE(status.IsOK()); @@ -514,7 +527,6 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { params5.trt_dump_ep_context_model = 1; params5.trt_ep_context_embed_mode = 1; params5.trt_ep_context_file_path = "EP_Context_model_2.onnx"; - model_name = ORT_TSTR("EPContextNode_test.onnx"); execution_provider = TensorrtExecutionProviderWithOptions(¶ms5); EXPECT_TRUE(session_object5.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); status = session_object5.Load(model_name); @@ -528,10 +540,10 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { InferenceSession session_object6{so, GetEnvironment()}; OrtTensorRTProviderOptionsV2 params6; params6.trt_ep_context_embed_mode = 1; - model_name = ToPathString(params5.trt_ep_context_file_path); + ctx_model_name = ToPathString(params5.trt_ep_context_file_path); execution_provider = TensorrtExecutionProviderWithOptions(¶ms6); EXPECT_TRUE(session_object6.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - status = session_object6.Load(model_name); + status = session_object6.Load(ctx_model_name); ASSERT_TRUE(status.IsOK()); status = session_object6.Initialize(); ASSERT_TRUE(status.IsOK()); @@ -543,6 +555,61 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { // Y: 1, 3, 3, 2, 2, 2 // Z: 1, 3, 3, 2, 2, 2 RunSession(session_object6, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + /* + * Test case 7: Run context model with ONNX in memory + */ + auto model_bytes = ReadFileFromDisk(model_name); + std::string ctx_model_name_str = "EP_Context_model_weight_stripped.onnx"; + ctx_model_name = ToPathString(ctx_model_name_str); + InferenceSession session_object7{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params7; + params7.trt_dump_ep_context_model = 1; + params7.trt_ep_context_embed_mode = 1; + params7.trt_weight_stripped_engine_enable = 1; + params7.trt_ep_context_file_path = ctx_model_name_str.c_str(); + execution_provider = TensorrtExecutionProviderWithOptions(¶ms7); + EXPECT_TRUE(session_object7.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object7.Load(model_bytes.data(), static_cast(model_bytes.size())); + ASSERT_TRUE(status.IsOK()); + status = session_object7.Initialize(); + std::cerr << status.ErrorMessage(); + ASSERT_TRUE(status.IsOK()); + RunSession(session_object7, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + /* + * Test case 7: Refit weightless context model with ONNX in memory + */ + auto ctx_model_bytes = ReadFileFromDisk(ctx_model_name); + InferenceSession session_object8{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params8; + params8.trt_weight_stripped_engine_enable = 1; + params8.trt_onnx_bytestream = model_bytes.data(); + 
params8.trt_onnx_bytestream_size = model_bytes.size(); + execution_provider = TensorrtExecutionProviderWithOptions(¶ms8); + EXPECT_TRUE(session_object8.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object8.Load(ctx_model_bytes.data(), static_cast(ctx_model_bytes.size())); + std::cerr << status.ErrorMessage(); + ASSERT_TRUE(status.IsOK()); + status = session_object8.Initialize(); + std::cerr << status.ErrorMessage(); + ASSERT_TRUE(status.IsOK()); + RunSession(session_object8, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + /* + * Test case 7: Refit weightless context model with ONNX from disk + */ + InferenceSession session_object9{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params9; + params9.trt_weight_stripped_engine_enable = 1; + params9.trt_onnx_model_folder_path = model_name_str.c_str(); + execution_provider = TensorrtExecutionProviderWithOptions(¶ms9); + EXPECT_TRUE(session_object9.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object9.Load(ctx_model_bytes.data(), static_cast(ctx_model_bytes.size())); + ASSERT_TRUE(status.IsOK()); + status = session_object9.Initialize(); + ASSERT_TRUE(status.IsOK()); + RunSession(session_object9, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); } TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e4814aa7fc033..892e7de8bb6ed 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -446,6 +446,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("use_tf32", ["1", "0"]) + test_get_and_set_option_with_values("sdpa_kernel", ["0", "1", "2"]) + option["gpu_external_alloc"] = "0" option["gpu_external_free"] = "0" option["gpu_external_empty_cache"] = "0" diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 1bbb933f66ba4..3b3790ba06599 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -17,6 +17,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Pow|Only supports cases when both inputs are fp32.| |ai.onnx:Relu|| |ai.onnx:Reshape|| +|ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| |ai:onnx:Tanh||
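Closing note on the sdpa_kernel option that the new Python test and the AttentionKernelOptions tests exercise: per those tests, 0 keeps the default backend selection driven by the ORT_* environment variables, while a non-zero value is treated as a bit mask of attention backends and makes most of those variables ineffective. A minimal C++ sketch of setting it through the provider-options string interface; this assumes the "sdpa_kernel" key is accepted by UpdateCUDAProviderOptions the same way the Python bindings accept it, and model.onnx is a placeholder:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;

  // Configure the CUDA EP with an explicit sdpa_kernel value ("0" = default behavior).
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"sdpa_kernel"};
  const char* values[] = {"0"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 1));
  Ort::ThrowOnError(
      Ort::GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options));
  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}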