diff --git a/.vscode/settings.json b/.vscode/settings.json index c4a08e3232a82..2f2adc78f6de9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -13,6 +13,7 @@ "editor.codeActionsOnSave": { "source.organizeImports": true }, + "editor.defaultFormatter": "ms-python.black-formatter" }, // Enable Python linting and Pylance type checking "python.analysis.typeCheckingMode": "basic", diff --git a/cmake/deps.txt b/cmake/deps.txt index 49142372ab86e..e065cacdfc423 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -54,4 +54,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/a4f72a314a85732ed67d5aa8d1088d207a7e0e61.zip;f57357ab6d300e207a632d034ebc8aa036a090d9 +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 9d9b006c595bb..c900f4d4b09a5 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -282,11 +282,7 @@ endif() # Assemble the Apple static framework (iOS and macOS) if(onnxruntime_BUILD_APPLE_FRAMEWORK) - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) - else() # macOS - set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) - endif() + set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT}) # Setup the various directories required. Remove any existing ones so we start with a clean directory. 
set(STATIC_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/static_libraries) diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index baea52e84ace2..6f09583199ffd 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -86,6 +86,8 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/core/optimizer/*.cc" "${ORTTRAINING_SOURCE_DIR}/core/optimizer/compute_optimizer/*.h" "${ORTTRAINING_SOURCE_DIR}/core/optimizer/compute_optimizer/*.cc" + "${ORTTRAINING_SOURCE_DIR}/core/optimizer/memory_optimizer/*.h" + "${ORTTRAINING_SOURCE_DIR}/core/optimizer/memory_optimizer/*.cc" ) endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index f2a16fb29dc62..cf298aee9fa85 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -172,10 +172,8 @@ target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) - include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) - endif() + include(cutlass) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index a52e941b235b4..df62199dc2b42 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -783,7 +783,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) - target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock) + target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index 02b30af9eef52..15844dd917744 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index b09da41a8..fca2bdf69 100644 +index 04674124c..12e8b8b00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ endif() @@ -48,7 +48,18 @@ index b09da41a8..fca2bdf69 100644 ## tidy include(EnableCompilerWarnings) -@@ -489,11 +466,3 @@ rocm_install(FILES +@@ -376,7 +353,9 @@ if(BUILD_DEV) + add_compile_options(-Werror -Weverything) + endif() + #add flags to reduce the size of binaries +-add_compile_options(-Oz -flto=thin) ++# -flto requires ORT to use a linker that supports LTO, and the -flto flag should be passed to the linker as well.
++# add_compile_options(-Oz -flto=thin) ++add_compile_options(-Oz) + message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +@@ -482,11 +461,3 @@ rocm_install(FILES set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") @@ -61,7 +72,7 @@ index b09da41a8..fca2bdf69 100644 - HEADER_ONLY -) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -index a0478c9f0..1e7782cd4 100644 +index 9cb5d0e9a..141a46f3d 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -44,8 +44,14 @@ function(add_instance_library INSTANCE_NAME) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 395996f0fa4b9..268ee3960e75a 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -451,6 +451,8 @@ onnxruntime_add_static_library(winml_lib_api ${winml_lib_api_dir}/impl/TensorKindFrom.h ${winml_lib_api_dir}/impl/TensorMemoryBufferReference.h ${winml_lib_api_dir}/NumericData.cpp + ${winml_lib_api_dir}/HardwareCoreEnumerator.cpp + ${winml_lib_api_dir}/HardwareCoreEnumerator.h ${winml_lib_api_dir}/ImageFeatureDescriptor.cpp ${winml_lib_api_dir}/ImageFeatureDescriptor.h ${winml_lib_api_dir}/ImageFeatureValue.cpp diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md index e9ceae00a684d..0147a937db81d 100644 --- a/docs/Memory_Optimizer.md +++ b/docs/Memory_Optimizer.md @@ -20,70 +20,115 @@ Not all models and recipes need this optimizer technique. Imagine if your traini ## Quick trial 1. Make sure ONNX Runtime training wheel is installed and correctly configured. -2. Integrate models using `ORTModule`, be noted log_level should be equal to or lower than DEVINFO. - > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO)) -3. Run the training as usual and redirect all outputs into the log file; then stop it after training a few steps. -4. Check the logging file, and search "Summary", you could find something like this: +2. Integrate models using `ORTModule`, and note that log_level should be equal to or lower than INFO. + > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO)) +3. Run the training as usual; then stop it after training a few steps. +4.
Check the logs; you will find something like this: ``` - MemoryOptimizer Summary: - User config: - - ================================= - ########Recompute######## - Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 - -------------------------------- - Subgraph: FastGelu+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24 - ================================= - ########RecomputeWithCompromise######## - Subgraph: Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24 - -------------------------------- - ================================= + Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_CONFIG=, available configs: + Config Freq Max Saving(B) Saving Symbolic(Bytes) + - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0) + - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) + - Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) + - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + + + Note 1: use comma as delimiter to enable multiple memory optimization plans at the same time: + export ORTMODULE_MEMORY_OPT_CONFIG=,,... + Note 2: memory saving is calculated based on the 1st batch symbolic dim values: + inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024, ``` -5. As shown above, 'Subgraph' shows 1) a string representative for a re-computable subgraph; and 2) current status of memory optimization. All are disabled for recompute in this case. -6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, 12 FastGelu related subgraphs are allowed to recompute. -`FastGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `12` means the initial 12 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed. +5. As shown above, each `Config` is a string representation of a re-computable subgraph. All are disabled for recompute in this case. +6. Set the environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable recompute for some of the subgraphs. In the example below, 6 `BiasGelu+` related subgraphs are allowed to recompute.
+`BiasGelu+` is the subgraph string representation; the `1` in the middle indicates 'Recompute' is enabled (`0`, on the contrary, indicates it is disabled); the `6` means the first 6 subgraph occurrences will be recomputed while all others are left as they are; passing `-1` makes all occurrences be recomputed. ``` - export ORTMODULE_MEMORY_OPT_CONFIG="FastGelu+:1:12" # Use comma as separator for enabling more than one subgraphs. + export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:6" # Use comma as separator for enabling more than one subgraphs. ``` -7. Then run the training again, you will see logs like this: +7. Then run the training again, and you will see logs like this: ``` - MemoryOptimizer Summary: - User config: - **FastGelu+:1:12** - ================================= - ########Recompute######## - Subgraph: CumSum+Sub+Mul+Unsqueeze+Cast+Mul+Cast+Reshape+Mul+FusedMatMul+Add+Reshape+Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:23 - -------------------------------- - Subgraph: FastGelu+ - OptimizationType: **Recompute (requested_count=12, actual applied_count=12)** - Patterns: - PatternShape:input_ids_dim0 x input_ids_dim1 x 4096 x Frequency:24 - ================================= - ########RecomputeWithCompromise######## - Subgraph: Cast+Where+Softmax+ - OptimizationType: Disabled - Patterns: - PatternShape:input_ids_dim0 x 16 x input_ids_dim1 x input_ids_dim1 x Frequency:24 - -------------------------------- - ================================= + Memory Optimizer : ON : User config: Reshape+Where+BiasSoftmax+:1:-1, probe level: 1, available configs: + Config Freq Max Saving(B) Saving Symbolic(Bytes) + - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0) + - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 + - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) + - Plan 5 : ON : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) + - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) + - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1 + - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 ``` 8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well. +## Optimization Configuration + +The basic optimization unit is represented by a unique `cluster id`; for example, `BiasGelu+` is one `cluster id`. +The `cluster id` is followed by the `optimization strategy`: 0 - none, 1 - recompute, 2 - recompute with compromised memory saving. +The `optimization strategy` is followed by the `request count`, i.e. how many occurrences to apply the given optimization to. Use `-1` to apply it to all occurrences. This gives the user a bit more flexibility to avoid unnecessary memory saving.
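As a minimal illustration (a sketch only; the plan names are taken from the example logs above and will differ for your model), the same configuration can be set from Python before the model is wrapped, since `ORTModule` picks up `ORTMODULE_MEMORY_OPT_CONFIG` from the environment when it builds the optimized training graph:

```python
# Sketch: enable recompute for the first 6 BiasGelu+ clusters plus the compromised-recompute
# Cast+ plan, using comma-delimited <cluster id>:<strategy>:<request count> entries.
import os

os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = "BiasGelu+:1:6,Cast+:2:-1"

import torch
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

pt_model = torch.nn.Linear(4, 4)  # stand-in for your real torch.nn.Module
ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
```

The variable only needs to be set before the first forward pass, which is when the training graph (and the memory optimization plan) is built.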
+ ## Compromised Recompute -If you check the above logs, there is a separate section called "RecomputeWithCompromise". Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it. +If you check the above logs, there is a config `Cast+:2:-1`; the `2` indicates it is a recomputation that can save only part of the stashed activation size, not all of it. Recomputing the subgraphs under it usually saves part of the activation (for example, half of it), not all of it. Enable it in the same way. + +## Memory Optimization Debug Infos + +Use the following log level: +> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO)) + +Besides the logs shown at `LogLevel.INFO`, you can also see the different node patterns that can take different optimization options. + +The table is generated as follows: +- A specific node might have different optimization options; we [generate](../orttraining/orttraining/core/optimizer/memory_optimizer/common.h#L124C26-L124C26) a hash (called `Node Cluster ID`) for the node according to all of its available optimization options. +- All nodes having the same `Node Cluster ID` are mapped into buckets; each bucket is displayed as one row. + +``` +MemoryInsight Summary - User config: not provided +=========================================================================================================================================== +|Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|6 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 | +| | Stashed Activations: | +| | - ReuseFreq : Output 0(6), | +| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(32)*(240))], byte/elem: 2, 100% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|5 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph FusedMatMul+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(inputs_input_ids_dim1)*(10240))], byte/elem: 2, 100% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|5 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph Cast+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(32)*(inputs_input_ids_dim1)*(inputs_input_ids_dim1))], byte/elem: 2, 100% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| +|1 |For each row options are mutually exclusive, only one of them can be enabled. | +| | | +| |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ | +| | Status : Disabled.
Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 100% saved | +| | | +| |>>Option 2 : RecomputeWithCompromise subgraph Cast+ | +| | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:2:-1 | +| | Stashed Activations: | +| | - Output 0 : [((inputs_input_ids_dim0)*(1)*(1)*(inputs_input_ids_dim1))], byte/elem: 4, 50% saved | +|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _| + +``` ## Notes diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 12733c3551704..7fa89cca381d9 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -269,6 +269,15 @@ data sparsity based performance optimizations. unset ORTMODULE_CACHE_DIR # Disable ``` +#### ORTMODULE_USE_EFFICIENT_ATTENTION + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. Note that it requires torch version 2.1.1 or above. There are some built-in patterns for attention fusion; if none of them works for your model, you can manually add a custom one in your user script. + + ```bash + export ORTMODULE_USE_EFFICIENT_ATTENTION=1 + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* @@ -397,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results ``` +#### ORTMODULE_USE_FLASH_ATTENTION + +- **Feature Area**: *ORTMODULE/TritonOp* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. Note that it requires ORTMODULE_USE_TRITON to be enabled and a CUDA device capability of 8.0 or above. There are some built-in patterns for attention fusion; if none of them works for your model, you can manually add a custom one in your user script. + + ```bash + export ORTMODULE_USE_FLASH_ATTENTION=1 + ``` + #### ORTMODULE_TRITON_DEBUG - **Feature Area**: *ORTMODULE/TritonOp* diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index b3783696b8d78..82a1c1de83523 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -2,34 +2,17 @@ // Licensed under the MIT License. #pragma once -#include -#include + #include -#include #include -#include "core/common/gsl.h" -#include "onnxruntime_config.h" - -#ifndef DISABLE_ABSEIL -// Need to include abseil inlined_vector.h header directly here -// as hash tables cause CUDA 10.2 compilers to fail. inlined_vector.h is fine. -#ifdef _MSC_VER -#pragma warning(push) -// C4127: conditional expression is constant -#pragma warning(disable : 4127) -// C4324: structure was padded due to alignment specifier -// Usage of alignas causes some internal padding in places.
-#pragma warning(disable : 4324) -#endif - -#include - -#ifdef _MSC_VER -#pragma warning(pop) -#endif -#endif // DISABLE_ABSEIL +#include +#include +#include +#include "core/common/gsl.h" +#include "core/common/inlined_containers_fwd.h" #include "core/common/span_utils.h" +#include "onnxruntime_config.h" namespace onnxruntime { #ifdef __GNUC__ @@ -41,18 +24,10 @@ namespace onnxruntime { constexpr size_t kTensorShapeSmallBufferElementsSize = 5; -#ifndef DISABLE_ABSEIL // Use this type to build a shape and then create TensorShape. -using TensorShapeVector = absl::InlinedVector; -#else -class TensorShapeVector : public std::vector { - using Base = std::vector; - - public: - using Base::Base; -}; - -#endif // DISABLE_ABSEIL +// We opt to re-use a common instantiation instead of a typedef with kTensorShapeSmallBufferElementsSize +// To reduce on binary size. +using TensorShapeVector = InlinedVector; inline TensorShapeVector ToShapeVector(const gsl::span& span) { TensorShapeVector out; @@ -194,9 +169,7 @@ class TensorShape { friend struct ProviderHostImpl; // So that the shared provider interface can access Allocate }; -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif + // operator<< to nicely output to a stream std::ostream& operator<<(std::ostream& out, const TensorShape& shape); diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 831def24e4f5e..4628afbb5a702 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -80,13 +80,13 @@ static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = #ifdef ENABLE_TRAINING // Specifies a list of op types for memory footprint reduction. // The value should be a ","-delimited list of pair of -// . +// . // For example, "Gelu+Cast+:1:0,Dropout+:1:1". // A valid "subgraph string" should be one subgraph representation output by ORT graph transformations. // "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute. // "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving" // the memory. -static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer"; +static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config"; // Specifies the level for detecting subgraphs for memory footprint reduction. // The value should be an integer. The default value is 0. diff --git a/js/README.md b/js/README.md index 7e6681e6bd897..1662de6d4ac78 100644 --- a/js/README.md +++ b/js/README.md @@ -344,13 +344,13 @@ From ORT v1.13 onwards the 'full' ONNX Runtime package is used. 
It supports both Full build: ```sh - python tools/ci_build/github/apple/build_ios_framework.py tools/ci_build/github/apple/default_full_ios_framework_build_settings.json --config Release + python tools/ci_build/github/apple/build_apple_framework.py tools/ci_build/github/apple/default_full_apple_framework_build_settings.json --config Release ``` Reduced size build: ```sh - python tools/ci_build/github/apple/build_ios_framework.py tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json --config MinSizeRel --include_ops_by_config --enable_reduced_operator_type_support + python tools/ci_build/github/apple/build_apple_framework.py tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json --config MinSizeRel --include_ops_by_config --enable_reduced_operator_type_support ``` The build creates `Headers`, `LICENSE`, and `onnxruntime.xcframework` in `build/iOS_framework/framework_out` directory. From `framework_out` directory, create an archive file named `onnxruntime-c.zip` for a full build or `onnxruntime-mobile-c.zip` for a reduced size build and copy to `/js/react_native/local_pods` directory. diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index dd04ef3f15997..67d283b694955 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -49,8 +49,9 @@ export interface TrainingSessionHandler extends SessionHandler { feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise; + getParametersSize(trainableOnly: boolean): Promise; loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; } /** diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index ee6d26b22b1f6..03694738387f2 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -176,12 +176,24 @@ export class TrainingSession implements TrainingSessionInterface { return this.convertHandlerReturnTypeToMapOfTensors(results); } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async getParametersSize(trainableOnly = true): Promise { + return this.handler.getParametersSize(trainableOnly); } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise { + const paramsSize = await this.getParametersSize(trainableOnly); + // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number + // of parameters + if (array.length !== 4 * paramsSize) { + throw new Error( + 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + + 'the model. Please use getParametersSize method to check.'); + } + return this.handler.loadParametersBuffer(array, trainableOnly); + } + + async getContiguousParameters(trainableOnly = true): Promise { + return this.handler.getContiguousParameters(trainableOnly); } async release(): Promise { diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index 0967d79b33434..810ec2a8583b3 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
import {InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -49,21 +50,33 @@ export interface TrainingSession { // #endregion // #region copy parameters + + /** + * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of + * the parameters) elements of all the parameters in the training state. + * + * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true. + */ + getParametersSize(trainableOnly: boolean): Promise; + /** - * Copies from a buffer containing parameters to the TrainingSession parameters. + * Copies parameter values from the given array to the training state. Currently, only supporting models with + * parameters of type Float32. * - * @param buffer - buffer containing parameters - * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. + * @param buffer - Float32 buffer containing parameters converted to a Uint8Array. + * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. */ loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; /** - * Copies from the TrainingSession parameters to a buffer. + * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. + * Currently, only supporting models with parameters of type Float32. * - * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. - * @returns A promise that resolves to a buffer of the requested parameters. + * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters + * for which requires_grad is set to true. Default value is true. + * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters. 
*/ - getContiguousParameters(trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; // #endregion // #region release() diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index b246e19137888..00c27fe3ab034 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -22,6 +22,7 @@ Do not modify directly.* | Atanh | ai.onnx(9+) | | | Attention | com.microsoft(1+) | need implementing mask and past/present | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | +| BatchNormalization | ai.onnx(7-8,9-13,14,15+); com.ms.internal.nhwc(7-8,9-13,14,15+) | | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index bac44328d8f44..80f6e3bc11195 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -3,6 +3,7 @@ import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; import {attention, parseAttentionAttributes} from './ops/attention'; +import {batchNorm} from './ops/batch-norm'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; @@ -51,6 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Attention', [attention, parseAttentionAttributes]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], + ['BatchNormalization', [batchNorm]], ['BiasAdd', [biasAdd]], ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 089e783d7e22f..3638938df7dbe 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,9 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -50,9 +49,9 @@ const conv2dCommonSnippet = const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'return w[row * wShape[3] + colIn];'; + return 'return w[row * i32(uniforms.w_shape[3]) + colIn];'; case 4: - return 'return w[row * wShape[3] / 4 + colIn];'; + return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];'; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); } @@ -79,13 +78,13 @@ const conv2dCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]'; - const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 
'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readXSnippet = ` - let inChannels = wShape[2]; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let inChannels = i32(uniforms.w_shape[2]); + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -99,7 +98,7 @@ const conv2dCommonSnippet = // the 'same' padding type. if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { ${coordASnippet} - let xIndex = getIndexFromCoords4D(coord, xShape); + let xIndex = getIndexFromCoords4D(coord, vec4(uniforms.x_shape)); ${getXSnippet(innerElementSizeX)} } return resData;`; @@ -109,7 +108,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimAOuter && col < dimInner) { + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -118,7 +117,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimInner && col < dimBOuter) { + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -143,10 +142,10 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimBOuter) + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} ${biasSnippet(addBias)} ${applyActivation} @@ -181,7 +180,7 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0]; + const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; @@ -194,10 +193,18 @@ export const createConv2DMatMulProgramInfo = const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; const t = tensorTypeToWsglStorageType(inputs[0].dataType); - const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? `vec4<${t}>` : t}>;` - ]; + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const x = + inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? 
`vec4<${t}>` : t}(value); @@ -207,41 +214,40 @@ export const createConv2DMatMulProgramInfo = setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? `vec4<${t}>` : t}>;`); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } - + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + programUniforms.push(...createTensorShapeVariables(outputShape)); return { name: 'Conv2DMatMul', shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, }), - getShaderSource: () => ` - ${utilFunctions} + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${declareInputs.join('')} - @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? `vec4<${t}>` : t}>; - //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; - - const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); - const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)} const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; ${declareFunctions} ${ conv2dCommonSnippet( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 85cf7bf87f52c..d425155857e14 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -36,16 +36,16 @@ const conv2dTransposeCommonSnippet = const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'return W[getIndexFromCoords4D(coord, 
wShape)];'; + return 'return w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))];'; case 4: return ` let coord1 = vec4(coordX, coordY, col + 1, rowInner); let coord2 = vec4(coordX, coordY, col + 2, rowInner); let coord3 = vec4(coordX, coordY, col + 3, rowInner); - let v0 = W[getIndexFromCoords4D(coord, wShape)]; - let v1 = W[getIndexFromCoords4D(coord1, wShape)]; - let v2 = W[getIndexFromCoords4D(coord2, wShape)]; - let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + let v0 = w[getIndexFromCoords4D(coord, vec4(uniforms.w_shape))]; + let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; + let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; + let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; return vec4(v0, v1, v2, v3); `; default: @@ -81,7 +81,7 @@ const conv2dTransposeCommonSnippet = const readASnippet = ` let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -99,17 +99,17 @@ const conv2dTransposeCommonSnippet = let iXC = i32(xC); let xCh = ${col} % inChannels; ${coordASnippet} - return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + return x[getIndexFromCoords4D(coord, vec4(uniforms.x_shape))/${innerElementSize}];`; const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimInner) { + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < dimInner && col < dimBOuter) { + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { ${readASnippet} } return ${type}(0.0);`; @@ -120,8 +120,8 @@ const conv2dTransposeCommonSnippet = let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; if (${ - isChannelsLast ? 'row < dimInner && col < dimBOuter' : - 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : + 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -142,13 +142,13 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimBOuter) { + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueInput; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} ${biasSnippet(addBias)} ${applyActivation} - result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + result[getIndexFromCoords4D(coords, vec4(uniforms.result_shape))/${innerElementSize}] = value; } }`; return userCode; @@ -185,37 +185,46 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + const components = isVec4 ? 
4 : 1; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); - - const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, - '@group(0) @binding(1) var W: array;' - ]; let declareFunctions = ''; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } + + programUniforms.push(...createTensorShapeVariables(outputShape)); + return { name: 'Conv2DTransposeMatMul', shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms }), - getShaderSource: () => ` - ${utilFunctions} - ${declareInputs.join('\n')} - @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)}; const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); - const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ attributes.kernelShape[isChannelsLast ? 
2 : 3]}); const effectiveFilterDims : vec2 = filterDims + vec2( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts index 0ba48a33fbc47..6f2c0231104dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts @@ -19,13 +19,13 @@ // // modified to fit the needs of the project -export const utilFunctions = ` +export const utilFunctions = (strideStr: string) => (` fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 { return dot(coords, vec4( shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1)); } fn getOutputIndexFromCoords(coords : vec4) -> i32 { return dot(coords, vec4( - outShapeStrides.x, outShapeStrides.y, outShapeStrides.z, 1)); + i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1)); } -`; +`); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 335de01c596b7..3e520571779e4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -21,8 +21,8 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,7 +112,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; @@ -322,7 +322,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? 
`i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -384,7 +384,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimAOuter && col < dimInner) + if(row < uniforms.dimAOuter && col < uniforms.dimInner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -396,7 +396,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimInner && col < dimBOuter) + if(row < uniforms.dimInner && col < uniforms.dimBOuter) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -406,7 +406,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < dimAOuter && col < dimBOuter) { + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -430,8 +430,11 @@ export const createMatmulProgramInfo = const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims); + const enableBatchUniforms = enableShapesUniforms(outerDims.length); + const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; + const batchDims = inputVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1, true); const variables = [batchDims]; const batchShapes = [outerDimsA, outerDimsB, outerDims]; const batchSize = ShapeUtil.size(outerDims); @@ -452,39 +455,81 @@ export const createMatmulProgramInfo = const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); - const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); - const output = - outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components); + + const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; + const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); + const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; + + const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; + const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); + const bShapeOrRank = enableBShapesUniforms ? 
bShapeTemp.length : bShapeTemp; + + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); variables.push(A); variables.push(B); variables.push(output); - const inputVariables = [A, B]; + const inputVariables = [batchDims, A, B]; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + if (enableBatchUniforms) { + programUniforms.push(...createTensorShapeVariables(outerDims)); + } + if (enableAShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(aShapeTemp)); + } + if (enableBShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(bShapeTemp)); + } + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); + inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + const hasBias = inputs.length > 2; const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast); if (hasBias) { const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dimAOuter: i32 = ${dimAOuter}; - const dimBOuter: i32 = ${dimBOuter}; - const dimInner: i32 = ${dimInner}; - ${shaderHelper.declareVariables(...inputVariables, output)} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)} ${activationFunction} ${declareFunctions} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} - ${batchDims.impl()}`; + `; + // TODO: turn clipMax and clipMin to uniforms. return { name: 'MatMul', - shaderCache: {hint: activationAttributes.activationCacheKey}, + shaderCache: { + hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + + `${activationAttributes.activation}` + + `${activationAttributes.clipMax}` + + `${activationAttributes.clipMin}` + + `${isVec4}` + + `${hasBias}` + + `${isChannelsLast}`, + inputDependencies + }, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts new file mode 100644 index 0000000000000..ec9da2613f406 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; + +export interface BatchNormAttributes extends AttributeWithCacheKey { + readonly epsilon: number; + readonly momentum: number; + readonly spatial: boolean; + readonly trainingMode: boolean; + readonly format: 'NHWC'|'NCHW'; + readonly outputCount: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => { + if (!inputs || inputs.length !== 5) { + throw new Error('BatchNormalization requires 5 inputs'); + } + + const checkShapeEqual = (actual: readonly number[], expected: readonly number[], message: string) => { + const r = expected.length; + if (r !== actual.length) { + throw new Error(`${message}: num dimensions != ${r}`); + } + expected.forEach((v, i) => { + if (v !== actual[i]) { + throw new Error(`${message}: dim[${i}] do not match`); + } + }); + }; + + if (inputs[0].dims.length > 1) { + const shape = attributes.format === 'NHWC' ? + (attributes.spatial ? inputs[0].dims.slice(-1) : + inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))) : + inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined); + checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale'); + checkShapeEqual(inputs[2].dims, shape, 'Invalid input B'); + checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean'); + checkShapeEqual(inputs[4].dims, shape, 'Invalid input var'); + } else { + checkShapeEqual(inputs[1].dims, [1], 'Invalid input scale'); + checkShapeEqual(inputs[2].dims, [1], 'Invalid input B'); + checkShapeEqual(inputs[3].dims, [1], 'Invalid input mean'); + checkShapeEqual(inputs[4].dims, [1], 'Invalid input var'); + } +}; + +const createBatchNormInferenceProgramInfo = + (inputs: readonly TensorView[], attributes: BatchNormAttributes): ProgramInfo => { + const {epsilon, spatial, format} = attributes; + const yShape = inputs[0].dims; + const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1; + const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; + const outputSize = ShapeUtil.size(yShape) / components; + // Only support uniforms for opset version >= 9 (spatial = true). + const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const shapeOrRank = useShapesUniforms ? yShape.length : yShape; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents); + const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents); + const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents); + const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components); + // TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type. + // Otherwise, the shader compilation will fail. + const calcCOffset = (): string => { + let cOffset = ''; + if (spatial) { + cOffset = `let cOffset = ${ + yShape.length === 1 ? 
'0u' : + format === 'NHWC' ? `outputIndices[${yShape.length - 1}] / ${components}` : + 'outputIndices[1]'};`; + } else { + if (format === 'NCHW') { + cOffset = ` + ${y.indicesSet('outputIndices', '0', '0')} + let cOffset = ${y.indicesToOffset('outputIndices')};`; + } else { + // update C channel. + cOffset = `var cIndices = ${scale.type.indices}(0); + cIndices[0] = outputIndices[${yShape.length - 1}];`; + // update D1 x ... x Dn channels. + for (let i = 1; i < scale.rank; i++) { + cOffset += `cIndices[${i}] = outputIndices[${i}];`; + } + cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`; + } + } + return cOffset; + }; + const getInferenceModeShaderSource = (helper: ShaderHelper) => ` + const epsilon = ${epsilon}; + ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)} + ${helper.mainStart()} + ${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)}; + ${calcCOffset()} + let scale = ${scale.getByOffset('cOffset')}; + let bias = ${bias.getByOffset('cOffset')}; + let inputMean = ${inputMean.getByOffset('cOffset')}; + let inputVar = ${inputVar.getByOffset('cOffset')}; + let x = ${x.getByOffset('global_idx')}; + let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias; + ${y.setByOffset('global_idx', 'value')} + }`; + return { + name: 'BatchNormalization', + shaderCache: { + hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`, + inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined, + }, + getShaderSource: getInferenceModeShaderSource, + getRunData: () => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? 
+ [ + {type: 'uint32', data: outputSize}, + ...createTensorShapeVariables(yShape), + ] : + [ + {type: 'uint32', data: outputSize}, + ], + }), + }; + }; + +export const parseBatchNormAttributes = (attributes: Record): BatchNormAttributes => + createAttributeWithCacheKey(attributes as Omit); + +export const batchNorm = (context: ComputeContext, attributes: Record): void => { + const {inputs, outputCount} = context; + const updatedAttributes = parseBatchNormAttributes({...attributes, outputCount}); + if (env.webgpu.validateInputContent) { + validateInputs(inputs, updatedAttributes); + } + if (attributes.trainingMode) { + throw new Error('BatchNormalization trainingMode is not supported yet.'); + } else { + context.compute(createBatchNormInferenceProgramInfo(inputs, updatedAttributes)); + } +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index 14eefc344f3c0..a81a7a8f1df5c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; import {erfImpl} from './unary-op'; const validateInputs = (inputs: readonly TensorView[]): void => { @@ -35,6 +35,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI const output = outputVariable('output', inputs[0].dataType, outputShape, 4); const outputSize = ShapeUtil.size(outputShape) / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const getShaderSource = (shaderHelper: ShaderHelper) => ` const M_SQRT2 = sqrt(2.0); @@ -42,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI ${shaderHelper.declareVariables(input, bias, output)} - ${erfImpl('vec4f')} + ${erfImpl(`vec4<${dataType}>`, dataType)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 014d9d02f6f10..f7ae18998b218 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -210,6 +210,11 @@ export interface IndicesHelper { * a string representing the variable name for the strides of the input or output. */ readonly strides: string; + + /** + * representing variable with uniforms, but without binding. + */ + readonly uniformOnly: boolean; } const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { @@ -335,8 +340,8 @@ export const sumVector = (name: string, components: number) => { * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, - components: 1|2|3|4): IndicesHelper => { + (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4, + uniformOnly = false): IndicesHelper => { const useUniform = typeof shapeOrRank === 'number'; const rank = useUniform ? shapeOrRank : shapeOrRank.length; const rankIdentity = [...new Array(rank).keys()]; @@ -358,7 +363,7 @@ const createIndicesHelper = getByIndices: false, }; - const uniformPrefix = useUniform ? 'uniforms.' 
: ''; + const uniformPrefix = useUniform || uniformOnly ? 'uniforms.' : ''; const shape = `${uniformPrefix}${name}_shape`; const strides = `${uniformPrefix}${name}_strides`; let o2iSnippet = ''; @@ -616,7 +621,8 @@ const createIndicesHelper = name, strides, shape, - rank + rank, + uniformOnly }; }; @@ -630,8 +636,8 @@ const createIndicesHelper = * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, true, components); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, uniformOnly = false): + IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, uniformOnly); /** * Create a IndicesHelper for an output. @@ -734,7 +740,7 @@ class ShaderHelperImpl implements ShaderHelper { `; } - private declareVariable(variable: IndicesHelper, bindingIndex: number): string { + private declareVariable(variable: IndicesHelper, bindingIndex = -1): string { this.indicesHelpers.push(variable); if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { @@ -744,13 +750,24 @@ class ShaderHelperImpl implements ShaderHelper { this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); } } + if (variable.uniformOnly) { + return ''; + } const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } declareVariables(...variables: IndicesHelper[]): string { - return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); + return variables + .map(v => { + if (v.uniformOnly === true) { + return this.declareVariable(v); + } else { + return this.declareVariable(v, this.variableIndex++); + } + }) + .join('\n'); } registerUniform(name: string, type: string): ShaderHelper { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 5680af4787b6a..d998013352d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -47,14 +47,18 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const outputSize = ShapeUtil.size(outputShape); const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, inputShape); - const output = outputVariable('output', dataType, outputShape); + const enableInputShapeUniform = enableShapesUniforms(inputShape.length); + const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; + const input = inputVariable('input', dataType, inputShapeOrRank); + const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapeUniform ? 
outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; var inputIndices: ${input.type.indices}; for (var i = 0; i < ${inputShape.length}; i++) { @@ -68,13 +72,21 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => } ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + if (enableInputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(inputShape)); + } + if (enableOutputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } return { name: 'Expand', - shaderCache: {hint: `${outputShape}`}, + shaderCache: {hint: `${outputShape}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 9869561a36251..973a607f9377e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -105,50 +105,51 @@ const validateInputs = } }; -const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string => - 'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\ - lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' + +const getOriginalCoordinateFromResizedCoordinate = + (coordinateTransferMode: CoordinateTransformMode, dType: string): string => + `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType}, + lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + (() => { - switch (coordinateTransferMode) { - case 'asymmetric': - return 'return xResized / xScale;'; - case 'pytorch_half_pixel': - return 'if (lengthResized > 1) { \ + switch (coordinateTransferMode) { + case 'asymmetric': + return 'return xResized / xScale;'; + case 'pytorch_half_pixel': + return 'if (lengthResized > 1) { \ return (xResized + 0.5) / xScale - 0.5; \ } else { \ return 0.0; \ }'; - case 'tf_half_pixel_for_nn': - return 'return (xResized + 0.5) / xScale;'; - case 'align_corners': - return 'if (lengthResized == 1) { \ + case 'tf_half_pixel_for_nn': + return 'return (xResized + 0.5) / xScale;'; + case 'align_corners': + return 'if (lengthResized == 1) { \ return 0.0; \ } else { \ return xResized * (lengthOriginal - 1) / (lengthResized - 1); \ }'; - case 'tf_crop_and_resize': - return 'if (lengthResized > 1) { \ + case 'tf_crop_and_resize': + return `if (lengthResized > 1) { \ return roiStart * (lengthOriginal - 1) + \ (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \ } else { \ - return 0.5 * (roiStart + roiEnd) * 
f32(lengthOriginal - 1); \ - }'; - case 'half_pixel_symmetric': - return [ - 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', - 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', - 'return offset + ((xResized + 0.5) / xScale) - 0.5;' - ].join('\n'); - case 'half_pixel': - return 'return ((xResized + 0.5) / xScale) - 0.5;'; - default: - throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); - } - })() + + return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \ + }`; + case 'half_pixel_symmetric': + return [ + 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', + 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', + 'return offset + ((xResized + 0.5) / xScale) - 0.5;' + ].join('\n'); + case 'half_pixel': + return 'return ((xResized + 0.5) / xScale) - 0.5;'; + default: + throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); + } + })() + '}'; -const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string => - 'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => { +const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string => + `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => { switch (nearestMode) { case 'round_prefer_ceil': return 'if (fract(xOriginal) == 0.5) { \ @@ -246,20 +247,21 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr const calculateOriginalIndicesFromOutputIndices = (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[]): string => ` - fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array { + fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${ + output.type.value}, ${outputShape.length}> { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); - var originalIndices: array; + const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')}); + const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')}); + var originalIndices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 
'outputIndices' : 'outputIndices[i]'}; if (scales[i] == 1.0) { - originalIndices[i] = f32(outputIndex); + originalIndices[i] = ${output.type.value}(outputIndex); } else { - originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); + originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i], + ${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${ + inputShape.length}]); } } return originalIndices; @@ -271,8 +273,8 @@ const calculateInputIndicesFromOutputIndices = fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); + const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')}); + const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')}); var inputIndices: ${input.type.indices}; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; @@ -280,12 +282,13 @@ const calculateInputIndicesFromOutputIndices = if (scales[i] == 1.0) { inputIndex = outputIndex; } else { - var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) { + var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i], + ${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${ + inputShape.length}]); + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) { if (original_idx < 0) { inputIndex = 0; - } else if (original_idx > (f32(inputShape[i]) - 1)) { + } else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) { inputIndex = inputShape[i] - 1; } else { inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1)); @@ -316,8 +319,9 @@ const bilinearInterpolation = useExtrapolation: boolean, extrapolationValue: number): string => { const [batchIdx, heightIdx, widthIdx, channelIdx] = inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? 
[0, 2, 3, 1] : [0, 1, 2, 3]); + const dType = input.type.value; return ` - fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 { + fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { var inputIndices: ${input.type.indices}; inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1)); inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1)); @@ -328,10 +332,10 @@ const bilinearInterpolation = return input[${input.indicesToOffset('inputIndices')}]; } - fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices); - var row:f32 = originalIndices[${heightIdx}]; - var col:f32 = originalIndices[${widthIdx}]; + var row:${dType} = originalIndices[${heightIdx}]; + var col:${dType} = originalIndices[${widthIdx}]; if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${ inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; @@ -348,14 +352,14 @@ const bilinearInterpolation = channel = u32(originalIndices[${channelIdx}]); batch = u32(originalIndices[${batchIdx}]); } - var x11: f32 = getInputValue(batch, channel, row1, col1); - var x12: f32 = getInputValue(batch, channel, row1, col2); - var x21: f32 = getInputValue(batch, channel, row2, col1); - var x22: f32 = getInputValue(batch, channel, row2, col2); - var dx1: f32 = row - f32(row1); - var dx2: f32 = f32(row2 ) - row; - var dy1 = col - f32(col1); - var dy2 = f32(col2) - col; + var x11: ${dType} = getInputValue(batch, channel, row1, col1); + var x12: ${dType} = getInputValue(batch, channel, row1, col2); + var x21: ${dType} = getInputValue(batch, channel, row2, col1); + var x22: ${dType} = getInputValue(batch, channel, row2, col2); + var dx1: ${dType} = row - ${dType}(row1); + var dx2: ${dType} = ${dType}(row2) - row; + var dy1 = col - ${dType}(col1); + var dy2 = ${dType}(col2) - col; return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); }`; }; @@ -365,24 +369,24 @@ const bicubicInterpolation = scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => { const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2]; - + const dType = input.type.value; const createCubicInterpolationFunction = (idx: number): string => { const direction = idx === heightIdx ? 'row' : 'col'; return ` fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${ - output.type.indices}) -> f32 { + output.type.indices}) -> ${dType} { var outputIndex = ${outputShape.length === 1 ? 
'outputIndices' : `outputIndices[${idx}]`}; - var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]}, - f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); - var fractOriginalIdx: f32 = originalIdx - floor(originalIdx); + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]}, + ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); + var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx); if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) { return ${extrapolationValue}; } - var data: array = array(0.0, 0.0, 0.0, 0.0); + var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); for (var i: i32 = -1; i < 3; i++) { - var ${direction}: f32 = originalIdx + f32(i); + var ${direction}: ${dType} = originalIdx + ${dType}(i); if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) { if (${excludeOutside}) { coefs[i + 1] = 0.0; @@ -405,12 +409,12 @@ const bicubicInterpolation = return ` ${createCubicInterpolationFunction(heightIdx)}; ${createCubicInterpolationFunction(widthIdx)}; - fn getCubicInterpolationCoefs(s: f32) -> array { + fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> { var absS = abs(s); - var coeffs: array = array(0.0, 0.0, 0.0, 0.0); - var oneMinusAbsS: f32 = 1.0 - absS; - var twoMinusAbsS: f32 = 2.0 - absS; - var onePlusAbsS: f32 = 1.0 + absS; + var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); + var oneMinusAbsS: ${dType} = 1.0 - absS; + var twoMinusAbsS: ${dType} = 2.0 - absS; + var onePlusAbsS: ${dType} = 1.0 + absS; coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${ cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA}; coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1; @@ -420,12 +424,12 @@ const bicubicInterpolation = return coeffs; } - fn cubicInterpolation1D(x: array, coefs: array) -> f32 { - var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3]; + fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} { + var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3]; return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum; } - fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var inputIndices: ${input.type.indices} = outputIndices; return colCubicInterpolation(inputIndices, outputIndices); } @@ -451,15 +455,16 @@ const createResizeProgramInfo = const outputSize = ShapeUtil.size(outputShape); const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; + const dataType = input.type.value; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${noScale ? 
'' : ` - ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; + ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)}; ${(() => { switch (attributes.mode) { case 'nearest': return ` ${checkInputIndices(input, inputShape)}; - ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)}; + ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)}; ${ calculateInputIndicesFromOutputIndices( input, output, inputShape, outputShape, scales, roi, useExtrapolation)}; diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 09d91591128d1..7de3f4dc2c89e 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,20 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } private sessionId: number; private checkpointId: number; @@ -124,6 +118,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); } + async getParametersSize(trainableOnly: boolean): Promise { + return getParametersSize(this.sessionId, trainableOnly); + } + + async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { + await loadParametersBuffer(this.sessionId, array, trainableOnly); + } + async getContiguousParameters(trainableOnly: boolean): Promise { + const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); + return decodeTensorMetadata(tensorResult); + } + async dispose(): Promise { return releaseTrainingSessionAndCheckpoint( this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index a35d285346db4..c0a4235113148 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -6,7 +6,7 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {tensorDataTypeEnumToString, 
tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; @@ -16,6 +16,22 @@ const NO_TRAIN_FUNCS_MSG = 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; +/** + * Calls the checkLastError function, which will throw an error, if the provided error code indicates a failure. + * @param errCode number to be evaluated as a potential error code + * @param message message to pass into checkLastError + * @param checkNeqZero when true, treats a non-zero value as an error. + * When false, treats a value of zero as an error. + */ +const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { + if (checkNeqZero && errCode !== 0) { + checkLastError(message); + } else if (!checkNeqZero && errCode === 0) { + checkLastError(message); + } +}; + export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -29,9 +45,7 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n throw new Error(NO_TRAIN_FUNCS_MSG); } - if (checkpointHandle === 0) { - checkLastError('Error occurred when trying to create a CheckpointState.'); - } + ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); return checkpointHandle; } catch (e) { if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { @@ -52,9 +66,7 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea if (wasm._OrtTrainingGetModelInputOutputCount) { const errorCode = wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } + ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); @@ -74,9 +86,7 @@ const getModelInputOutputNamesLoop = for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetModelInputOutputName) { const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - if (name === 0) { - checkLastError('Can\'t get input or output name'); - } + ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); @@ -122,9 +132,7 @@ export const createTrainingSessionHandle = throw new Error(NO_TRAIN_FUNCS_MSG); } - if (trainingSessionHandle === 0) { - checkLastError('Error occurred when trying to create a TrainingSession.'); - } + ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = getTrainingModelInputOutputNames(trainingSessionHandle); @@ -213,9 +221,8 @@ const moveOutputToTensorMetadataArr = try { const errorCode = wasm._OrtGetTensorData( tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode
!== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } + ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); + let tensorDataIndex = tensorDataOffset / 4; const dataType = wasm.HEAPU32[tensorDataIndex++]; dataOffset = wasm.HEAPU32[tensorDataIndex++]; @@ -290,10 +297,7 @@ export const runTrainStep = async( if (wasm._OrtTrainingRunTrainStep) { const errorCode = wasm._OrtTrainingRunTrainStep( trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); - - if (errorCode !== 0) { - checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -313,6 +317,128 @@ export const runTrainStep = async( } }; +export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + try { + const sizeOffset = wasm.stackAlloc(4); + if (wasm._OrtTrainingGetParametersSize) { + const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size'); + + return wasm.HEAP32[sizeOffset / 4]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +export const getContiguousParameters = + async(trainingSessionId: number, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + const parametersSize = getParametersSize(trainingSessionId, trainableOnly); + let tensor = 0; + + // allocates a buffer of the correct size on the WASM heap + const paramsByteLength = 4 * parametersSize; + const paramsOffset = wasm._malloc(paramsByteLength); + + // handles the dimensions-related createTensor parameters + const dims = [parametersSize]; + + const dimsOffset = wasm.stackAlloc(4); + const dimsIndex = dimsOffset / 4; + wasm.HEAP32[dimsIndex] = parametersSize; + + try { + // wraps allocated array in a tensor + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError( + tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false); + + if (wasm._OrtTrainingCopyParametersToBuffer) { + const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.'); + + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); + const data = new typedArrayConstructor(parametersSize); + const output: TensorMetadata[] = []; + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + output.push([tensorTypeAsString, dims, data, locationAsString]); + if (output.length !== 1) { + throw new Error(`something unexpected happened in the getContiguousParameters function. 
Expected output length of + one, got ${output.length}`); + } else { + return output[0]; + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm._free(paramsOffset); + wasm._free(dimsOffset); + wasm.stackRestore(stack); + } +}; + +export const loadParametersBuffer = + async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + // allocates & copies JavaScript buffer to WASM heap + const bufferByteLength = buffer.length; + const bufferCount = bufferByteLength / 4; + const bufferOffset = wasm._malloc(bufferByteLength); + wasm.HEAPU8.set(buffer, bufferOffset); + + // allocates and handles moving dimensions information to WASM memory + const dimsOffset = wasm.stackAlloc(4); + wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsLength = 1; + let tensor = 0; + + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); + + if (wasm._OrtTrainingCopyParametersFromBuffer) { + const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm.stackRestore(stack); + wasm._free(bufferOffset); + wasm._free(dimsOffset); + } +}; + export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { diff --git a/js/web/test/data/ops/batch-norm.jsonc b/js/web/test/data/ops/batch-norm.jsonc new file mode 100644 index 0000000000000..4ea16f290dc8f --- /dev/null +++ b/js/web/test/data/ops/batch-norm.jsonc @@ -0,0 +1,446 @@ +[ + { + "name": "BatchNormalization with no attributes", + "operator": "BatchNormalization", + "attributes": [], + "cases": [ + { + "name": "T[64]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425 + ], + "dims": [64], + "type": "float32" + }, + { + "data": [0.241661], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 
0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189 + ], + "dims": [64], + "type": "float32" + } + ] + }, + { + "name": "T[2,3,4,4,4]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425, 0.168795, 0.740422, -0.377683, 0.432598, -2.07414, -2.85251, 0.273531, + 0.0532606, 1.31052, -0.769382, 0.9976, 0.850536, -1.53812, -0.00496016, 0.931242, 0.0517056, -0.497829, + 0.275869, 0.860001, 1.23747, 0.179686, 1.5914, 0.740327, 0.798208, 2.12478, 1.74205, -0.322054, + -0.0112451, 0.204525, -0.431252, -1.3114, 0.186204, 0.780569, -1.42994, 1.63344, -0.00839034, -0.187035, + 1.8406, 1.32053, -0.636963, 0.408944, -1.50846, -1.2076, -0.129118, -0.0441307, 1.47558, 1.07251, 1.05295, + -0.420297, -1.13402, -0.524053, 3.20754, -0.588935, -0.527549, 0.591928, -1.10529, 0.520412, 0.19404, + -1.21229, -0.399594, -0.280935, -0.363324, -0.00804771, 1.43102, -0.523222, 1.17608, -0.53195, 0.914993, + 2.69308, -0.517211, 0.472273, -0.464725, -0.929768, -0.631145, 0.919709, -0.27391, 1.76689, 0.894897, + 0.235798, 1.2544, 0.858985, -0.139707, 0.354544, 0.200878, 0.353255, 0.0722632, -1.56074, 1.03685, + 1.73434, 0.193269, -0.864609, 0.842739, -0.372717, 0.584484, 0.16315, 1.60674, -0.0611289, -1.24544, + 1.33361, -0.961942, -0.15732, -0.348637, 0.361842, 0.7386, 0.517256, 1.20406, -2.07277, -1.01983, -1.9163, + 0.239934, 0.177979, 0.464564, 0.988822, 0.284607, -1.56099, -0.429143, 0.111043, -0.0853688, -0.319176, + -0.279777, 0.520971, -1.078, -0.670242, 0.065652, 0.468538, -0.825062, 0.370068, 1.68751, -1.16928, + -0.411782, 1.61624, -0.973004, 2.64703, -0.220014, -1.43954, -0.018692, 1.34982, -0.95197, -1.72586, + 1.32725, 0.280984, 0.00847463, 0.512869, 0.0378154, 0.13898, 0.35758, -0.084558, 1.04045, -1.79933, + 1.3002, 0.390457, 1.22267, 0.959344, -0.964296, -0.0935597, 0.288953, -0.158046, 0.532672, -0.500988, + 0.25187, -2.14384, -0.633315, 1.24612, -1.41525, 0.36494, -0.00714732, -0.608963, 0.508496, 0.995365, + 1.21159, -0.169055, -0.968783, 1.52779, -0.082381, 2.2049, 0.928655, 0.120245, 0.911429, -0.885258, + -1.2072, 0.770694, 2.36621, 1.08456, -1.60069, 0.0345025, 0.359559, -0.785411, 0.466532, -0.78543, + 0.024879, 1.59337, 1.13718, -1.27073, -0.263788, -1.7702, 0.203263, 1.34631, 1.11914, -2.04911, -0.804137, + 0.466763, 2.18386, 1.4689, 0.898297, -0.648948, 0.252202, 1.12501, -0.204563, 0.124608, 0.377214, + 0.894327, -0.249118, 0.709188, 0.999397, -1.4079, 0.193873, 0.657753, -0.709732, 1.09897, -0.145793, + 0.779199, 0.88378, -1.2676, 1.15709, 0.62295, -0.370894, -0.103268, -1.55949, 
-0.470747, 0.100394, + 0.422334, -0.0685312, -0.434488, -0.568974, -0.256987, 2.01276, -0.923322, -0.613144, 1.50676, 0.65756, + 1.20524, 1.10395, -0.975241, 2.44035, 1.08276, 0.330393, -0.508918, -1.25545, 0.189815, -0.156263, + -0.960866, 1.0859, -0.674478, 2.76743, 1.21399, 1.71666, -1.73198, -1.1062, 0.951285, -0.713336, 1.61586, + 1.96514, 0.002603, 0.0953297, 0.949256, -1.76552, 0.372816, -0.781229, 1.50532, 1.28462, 1.31116, + 0.731908, 1.54835, 0.371081, 0.409244, -0.106938, -1.79396, -1.61198, -0.80869, -1.10381, 1.1872, + -0.832439, 0.0755941, -1.09553, 0.960059, 1.44252, -0.196482, -1.07364, 0.165547, 0.630078, 1.56569, + -0.669592, 1.15974, 0.0953399, -0.202313, 0.812631, -0.318567, -0.16644, 0.887062, -0.0264821, -0.740725, + 0.0797577, -1.1037, 0.90236, 1.13427, 0.364186, -2.01043, -0.415748, 0.116046, 0.369949, 0.317886, + 0.530332, 1.48341, 0.74666, -1.64142, 0.22569, 1.18015, 1.31827, -1.33904, -0.101125 + ], + "dims": [2, 3, 4, 4, 4], + "type": "float32" + }, + { + "data": [0.241661, 0.960798, 0.474727], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189, 0.162177, 0.711393, -0.362876, 0.415637, + -1.99282, -2.74067, 0.262807, 0.0511725, 1.25914, -0.739217, 0.958488, 0.817189, -1.47782, -0.00476569, + 0.894731, 0.0496784, -0.478311, 0.265053, 0.826283, 1.18895, 0.172641, 1.52901, 0.711301, 0.766913, + 2.04147, 1.67375, -0.309427, -0.0108042, 0.196507, -0.414344, -1.25999, 0.178903, 0.749965, -1.37387, + 1.5694, -0.00806138, -0.179702, 1.76844, 1.26875, -0.61199, 0.392911, -1.44932, -1.16025, -0.124055, + -0.0424004, 1.41773, 1.03046, 1.01167, -0.403818, -1.08956, -0.503507, 3.08178, -0.565845, -0.506866, + 0.56872, -1.06196, 0.500008, 0.186433, -1.16476, -0.383928, -0.269921, -0.349079, -0.00773219, 1.37492, + -0.248386, 0.558316, -0.25253, 0.43437, 1.27847, -0.245533, 0.2242, -0.220617, -0.441384, -0.29962, + 0.436609, -0.130032, 0.838785, 0.424829, 0.111939, 0.595496, 0.407781, -0.0663221, 0.168311, 0.0953618, + 0.167699, 0.0343051, -0.74092, 0.492219, 0.823334, 0.0917494, -0.410451, 0.400069, -0.176938, 0.277469, + 0.0774512, 0.762761, -0.0290194, -0.59124, 0.6331, -0.456657, -0.0746837, -0.165507, 0.171775, 0.350631, + 0.245554, 0.571595, -0.983996, -0.484139, -0.909715, 0.113902, 0.0844908, 0.22054, 0.469418, 0.13511, + -0.741041, -0.203725, 0.0527148, -0.0405267, -0.151521, -0.132817, 0.247318, -0.511752, -0.31818, + 0.0311666, 0.222426, -0.391677, 0.17568, 0.801104, -0.282569, -0.0995112, 0.39058, -0.235136, 0.639682, + -0.0531687, -0.347878, -0.0045171, 0.326198, -0.230053, -0.41707, 0.320744, 0.0679025, 0.00204798, + 0.12394, 0.00913847, 0.0335859, 0.0864127, 
-0.0204343, 0.251436, -0.434827, 0.314206, 0.0943579, 0.295471, + 0.231835, -0.233032, -0.0226096, 0.0698283, -0.0381934, 0.128725, -0.121069, 0.060867, -0.51808, + -0.153047, 0.301137, -0.342009, 0.0881915, -0.00172722, -0.147162, 0.122883, 0.24054, 0.292792, + -0.0408538, -0.234116, 0.369206, -0.0199082, 0.532835, 0.224419, 0.0290583, 0.220256, -0.213931, + -0.291733, 0.186246, 0.571817, 0.262095, -0.386822, 0.00833788, 0.086891, -0.189802, 0.112742, -0.189807, + 0.00601226, 0.385054, 0.274811, -1.22091, -0.253445, -1.7008, 0.195294, 1.29353, 1.07526, -1.96877, + -0.772609, 0.448463, 2.09824, 1.4113, 0.863078, -0.623505, 0.242314, 1.0809, -0.196543, 0.119722, + 0.362425, 0.859263, -0.239351, 0.681383, 0.960214, -1.3527, 0.186272, 0.631964, -0.681905, 1.05588, + -0.140077, 0.748649, 0.84913, -1.2179, 1.11172, 0.598526, -0.356353, -0.099219, -1.49835, -0.452291, + 0.0964582, 0.405776, -0.0658444, -0.417454, -0.546667, -0.246911, 1.93385, -0.887121, -0.589104, 1.44769, + 0.631779, 1.15798, 1.06067, -0.937005, 2.34467, 1.04031, 0.31744, -0.488965, -1.20623, 0.182373, + -0.150136, -0.923194, 1.04332, -0.648034, 2.65893, 1.1664, 1.64935, -0.822216, -0.525139, 0.451599, + -0.338638, 0.767087, 0.932899, 0.00123571, 0.0452554, 0.450635, -0.838136, 0.176985, -0.370868, 0.714614, + 0.60984, 0.622438, 0.347455, 0.73504, 0.176161, 0.194278, -0.0507662, -0.851639, -0.765246, -0.383905, + -0.524005, 0.563593, -0.395179, 0.0358864, -0.520076, 0.455763, 0.684801, -0.093275, -0.509682, 0.0785892, + 0.299113, 0.743272, -0.317872, 0.550556, 0.0452602, -0.0960432, 0.385776, -0.151232, -0.079013, 0.42111, + -0.0125717, -0.35164, 0.0378629, -0.523955, 0.428372, 0.538468, 0.172888, -0.954402, -0.197366, 0.0550898, + 0.175624, 0.150908, 0.251761, 0.704209, 0.354458, -0.779221, 0.107141, 0.560244, 0.625814, -0.635675, + -0.0480064 + ], + "dims": [2, 3, 4, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization with no attributes - NHWC", + "operator": "BatchNormalization", + "opset": { "domain": "com.ms.internal.nhwc", "version": 12 }, + "attributes": [], + "cases": [ + { + "name": "T[64]", + "inputs": [ + { + "data": [ + 2.02384, -0.935186, 0.488569, -0.513934, -1.27082, -0.131913, -1.806, -0.37904, 0.667796, -1.14826, + 1.2522, 0.0300339, 2.4758, 1.55511, 0.385341, 1.46645, -1.09355, -2.56309, 0.976015, -1.47036, 0.89486, + 0.580989, -1.12418, -0.339189, 1.3314, 0.418893, -0.301401, -1.2983, -0.839063, 0.170261, 1.15486, + -0.255735, -0.589851, -0.416289, -0.952648, -0.360487, 0.253287, 0.437195, 0.32023, 0.209606, -0.279519, + -0.546527, 0.265286, -1.07383, -1.65879, 1.1222, 0.946612, 0.822549, 0.64689, -0.292639, -0.73995, + -0.694949, 1.33899, -0.0652476, 1.61791, 1.49692, -0.761145, -0.201874, -1.15431, -1.83111, -0.705267, + -0.143026, -0.129819, -0.799425 + ], + "dims": [64], + "type": "float32" + }, + { + "data": [0.241661], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [0], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, -0.225997, 0.118068, -0.124197, -0.307105, -0.031878, -0.436439, -0.0915989, 0.16138, -0.277489, + 0.302606, 0.007258, 0.598301, 0.375807, 0.0931215, 0.354382, -0.264267, -0.619395, 0.235864, -0.355328, + 0.216252, 0.140402, -0.271669, -0.0819684, 0.321747, 0.10123, -0.0728365, -0.313746, -0.202768, 0.0411454, + 0.279085, -0.0618009, -0.142543, -0.1006, -0.230217, -0.0871152, 0.0612094, 0.105652, 
0.0773867, + 0.0506533, -0.0675486, -0.132074, 0.064109, -0.259501, -0.400863, 0.271191, 0.228758, 0.198777, 0.156327, + -0.0707191, -0.178816, -0.167941, 0.323581, -0.0157677, 0.390985, 0.361745, -0.183938, -0.0487849, + -0.27895, -0.442507, -0.170435, -0.0345637, -0.031372, -0.193189 + ], + "dims": [64], + "type": "float32" + } + ] + }, + { + "name": "T[2,4,4,4,3]", + "inputs": [ + { + "data": [ + 2.02384, 0.168795, -0.523222, -0.935186, 0.740422, 1.17608, 0.488569, -0.377683, -0.53195, -0.513934, + 0.432598, 0.914993, -1.27082, -2.07414, 2.69308, -0.131913, -2.85251, -0.517211, -1.806, 0.273531, + 0.472273, -0.37904, 0.0532606, -0.464725, 0.667796, 1.31052, -0.929768, -1.14826, -0.769382, -0.631145, + 1.2522, 0.9976, 0.919709, 0.0300339, 0.850536, -0.27391, 2.4758, -1.53812, 1.76689, 1.55511, -0.00496016, + 0.894897, 0.385341, 0.931242, 0.235798, 1.46645, 0.0517056, 1.2544, -1.09355, -0.497829, 0.858985, + -2.56309, 0.275869, -0.139707, 0.976015, 0.860001, 0.354544, -1.47036, 1.23747, 0.200878, 0.89486, + 0.179686, 0.353255, 0.580989, 1.5914, 0.0722632, -1.12418, 0.740327, -1.56074, -0.339189, 0.798208, + 1.03685, 1.3314, 2.12478, 1.73434, 0.418893, 1.74205, 0.193269, -0.301401, -0.322054, -0.864609, -1.2983, + -0.0112451, 0.842739, -0.839063, 0.204525, -0.372717, 0.170261, -0.431252, 0.584484, 1.15486, -1.3114, + 0.16315, -0.255735, 0.186204, 1.60674, -0.589851, 0.780569, -0.0611289, -0.416289, -1.42994, -1.24544, + -0.952648, 1.63344, 1.33361, -0.360487, -0.00839034, -0.961942, 0.253287, -0.187035, -0.15732, 0.437195, + 1.8406, -0.348637, 0.32023, 1.32053, 0.361842, 0.209606, -0.636963, 0.7386, -0.279519, 0.408944, 0.517256, + -0.546527, -1.50846, 1.20406, 0.265286, -1.2076, -2.07277, -1.07383, -0.129118, -1.01983, -1.65879, + -0.0441307, -1.9163, 1.1222, 1.47558, 0.239934, 0.946612, 1.07251, 0.177979, 0.822549, 1.05295, 0.464564, + 0.64689, -0.420297, 0.988822, -0.292639, -1.13402, 0.284607, -0.73995, -0.524053, -1.56099, -0.694949, + 3.20754, -0.429143, 1.33899, -0.588935, 0.111043, -0.0652476, -0.527549, -0.0853688, 1.61791, 0.591928, + -0.319176, 1.49692, -1.10529, -0.279777, -0.761145, 0.520412, 0.520971, -0.201874, 0.19404, -1.078, + -1.15431, -1.21229, -0.670242, -1.83111, -0.399594, 0.065652, -0.705267, -0.280935, 0.468538, -0.143026, + -0.363324, -0.825062, -0.129819, -0.00804771, 0.370068, -0.799425, 1.43102, 1.68751, -1.16928, -1.27073, + -1.73198, -0.411782, -0.263788, -1.1062, 1.61624, -1.7702, 0.951285, -0.973004, 0.203263, -0.713336, + 2.64703, 1.34631, 1.61586, -0.220014, 1.11914, 1.96514, -1.43954, -2.04911, 0.002603, -0.018692, + -0.804137, 0.0953297, 1.34982, 0.466763, 0.949256, -0.95197, 2.18386, -1.76552, -1.72586, 1.4689, + 0.372816, 1.32725, 0.898297, -0.781229, 0.280984, -0.648948, 1.50532, 0.00847463, 0.252202, 1.28462, + 0.512869, 1.12501, 1.31116, 0.0378154, -0.204563, 0.731908, 0.13898, 0.124608, 1.54835, 0.35758, 0.377214, + 0.371081, -0.084558, 0.894327, 0.409244, 1.04045, -0.249118, -0.106938, -1.79933, 0.709188, -1.79396, + 1.3002, 0.999397, -1.61198, 0.390457, -1.4079, -0.80869, 1.22267, 0.193873, -1.10381, 0.959344, 0.657753, + 1.1872, -0.964296, -0.709732, -0.832439, -0.0935597, 1.09897, 0.0755941, 0.288953, -0.145793, -1.09553, + -0.158046, 0.779199, 0.960059, 0.532672, 0.88378, 1.44252, -0.500988, -1.2676, -0.196482, 0.25187, + 1.15709, -1.07364, -2.14384, 0.62295, 0.165547, -0.633315, -0.370894, 0.630078, 1.24612, -0.103268, + 1.56569, -1.41525, -1.55949, -0.669592, 0.36494, -0.470747, 1.15974, -0.00714732, 0.100394, 0.0953399, + -0.608963, 
0.422334, -0.202313, 0.508496, -0.0685312, 0.812631, 0.995365, -0.434488, -0.318567, 1.21159, + -0.568974, -0.16644, -0.169055, -0.256987, 0.887062, -0.968783, 2.01276, -0.0264821, 1.52779, -0.923322, + -0.740725, -0.082381, -0.613144, 0.0797577, 2.2049, 1.50676, -1.1037, 0.928655, 0.65756, 0.90236, + 0.120245, 1.20524, 1.13427, 0.911429, 1.10395, 0.364186, -0.885258, -0.975241, -2.01043, -1.2072, 2.44035, + -0.415748, 0.770694, 1.08276, 0.116046, 2.36621, 0.330393, 0.369949, 1.08456, -0.508918, 0.317886, + -1.60069, -1.25545, 0.530332, 0.0345025, 0.189815, 1.48341, 0.359559, -0.156263, 0.74666, -0.785411, + -0.960866, -1.64142, 0.466532, 1.0859, 0.22569, -0.78543, -0.674478, 1.18015, 0.024879, 2.76743, 1.31827, + 1.59337, 1.21399, -1.33904, 1.13718, 1.71666, -0.101125 + ], + "dims": [2, 4, 4, 4, 3], + "type": "float32" + }, + { + "data": [0.241661, 0.960798, 0.474727], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [0, 0, 0], + "dims": [3], + "type": "float32" + }, + { + "data": [1, 1, 1], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.489082, 0.162177, -0.248386, -0.225997, 0.711393, 0.558316, 0.118068, -0.362876, -0.25253, -0.124197, + 0.415637, 0.43437, -0.307105, -1.99282, 1.27847, -0.031878, -2.74067, -0.245533, -0.436439, 0.262807, + 0.2242, -0.0915989, 0.0511725, -0.220617, 0.16138, 1.25914, -0.441384, -0.277489, -0.739217, -0.29962, + 0.302606, 0.958488, 0.436609, 0.007258, 0.817189, -0.130032, 0.598301, -1.47782, 0.838785, 0.375807, + -0.00476569, 0.424829, 0.0931215, 0.894731, 0.111939, 0.354382, 0.0496784, 0.595496, -0.264267, -0.478311, + 0.407781, -0.619395, 0.265053, -0.0663221, 0.235864, 0.826283, 0.168311, -0.355328, 1.18895, 0.0953618, + 0.216252, 0.172641, 0.167699, 0.140402, 1.52901, 0.0343051, -0.271669, 0.711301, -0.74092, -0.0819684, + 0.766913, 0.492219, 0.321747, 2.04147, 0.823334, 0.10123, 1.67375, 0.0917494, -0.0728365, -0.309427, + -0.410451, -0.313746, -0.0108042, 0.400069, -0.202768, 0.196507, -0.176938, 0.0411454, -0.414344, + 0.277469, 0.279085, -1.25999, 0.0774512, -0.0618009, 0.178903, 0.762761, -0.142543, 0.749965, -0.0290194, + -0.1006, -1.37387, -0.59124, -0.230217, 1.5694, 0.6331, -0.0871152, -0.00806138, -0.456657, 0.0612094, + -0.179702, -0.0746837, 0.105652, 1.76844, -0.165507, 0.0773867, 1.26875, 0.171775, 0.0506533, -0.61199, + 0.350631, -0.0675486, 0.392911, 0.245554, -0.132074, -1.44932, 0.571595, 0.064109, -1.16025, -0.983996, + -0.259501, -0.124055, -0.484139, -0.400863, -0.0424004, -0.909715, 0.271191, 1.41773, 0.113902, 0.228758, + 1.03046, 0.0844908, 0.198777, 1.01167, 0.22054, 0.156327, -0.403818, 0.469418, -0.0707191, -1.08956, + 0.13511, -0.178816, -0.503507, -0.741041, -0.167941, 3.08178, -0.203725, 0.323581, -0.565845, 0.0527148, + -0.0157677, -0.506866, -0.0405267, 0.390985, 0.56872, -0.151521, 0.361745, -1.06196, -0.132817, -0.183938, + 0.500008, 0.247318, -0.0487849, 0.186433, -0.511752, -0.27895, -1.16476, -0.31818, -0.442507, -0.383928, + 0.0311666, -0.170435, -0.269921, 0.222426, -0.0345637, -0.349079, -0.391677, -0.031372, -0.00773219, + 0.17568, -0.193189, 1.37492, 0.801104, -0.282569, -1.22091, -0.822216, -0.0995112, -0.253445, -0.525139, + 0.39058, -1.7008, 0.451599, -0.235136, 0.195294, -0.338638, 0.639682, 1.29353, 0.767087, -0.0531687, + 1.07526, 0.932899, -0.347878, -1.96877, 0.00123571, -0.0045171, -0.772609, 0.0452554, 0.326198, 0.448463, + 0.450635, -0.230053, 2.09824, -0.838136, -0.41707, 1.4113, 0.176985, 0.320744, 
0.863078, -0.370868, + 0.0679025, -0.623505, 0.714614, 0.00204798, 0.242314, 0.60984, 0.12394, 1.0809, 0.622438, 0.00913847, + -0.196543, 0.347455, 0.0335859, 0.119722, 0.73504, 0.0864127, 0.362425, 0.176161, -0.0204343, 0.859263, + 0.194278, 0.251436, -0.239351, -0.0507662, -0.434827, 0.681383, -0.851639, 0.314206, 0.960214, -0.765246, + 0.0943579, -1.3527, -0.383905, 0.295471, 0.186272, -0.524005, 0.231835, 0.631964, 0.563593, -0.233032, + -0.681905, -0.395179, -0.0226096, 1.05588, 0.0358864, 0.0698283, -0.140077, -0.520076, -0.0381934, + 0.748649, 0.455763, 0.128725, 0.84913, 0.684801, -0.121069, -1.2179, -0.093275, 0.060867, 1.11172, + -0.509682, -0.51808, 0.598526, 0.0785892, -0.153047, -0.356353, 0.299113, 0.301137, -0.099219, 0.743272, + -0.342009, -1.49835, -0.317872, 0.0881915, -0.452291, 0.550556, -0.00172722, 0.0964582, 0.0452602, + -0.147162, 0.405776, -0.0960432, 0.122883, -0.0658444, 0.385776, 0.24054, -0.417454, -0.151232, 0.292792, + -0.546667, -0.079013, -0.0408538, -0.246911, 0.42111, -0.234116, 1.93385, -0.0125717, 0.369206, -0.887121, + -0.35164, -0.0199082, -0.589104, 0.0378629, 0.532835, 1.44769, -0.523955, 0.224419, 0.631779, 0.428372, + 0.0290583, 1.15798, 0.538468, 0.220256, 1.06067, 0.172888, -0.213931, -0.937005, -0.954402, -0.291733, + 2.34467, -0.197366, 0.186246, 1.04031, 0.0550898, 0.571817, 0.31744, 0.175624, 0.262095, -0.488965, + 0.150908, -0.386822, -1.20623, 0.251761, 0.00833788, 0.182373, 0.704209, 0.086891, -0.150136, 0.354458, + -0.189802, -0.923194, -0.779221, 0.112742, 1.04332, 0.107141, -0.189807, -0.648034, 0.560244, 0.00601226, + 2.65893, 0.625814, 0.385054, 1.1664, -0.635675, 0.274811, 1.64935, -0.0480064 + ], + "dims": [2, 4, 4, 4, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization non-spatial mode", + "operator": "BatchNormalization", + "opset": { "domain": "", "version": 7 }, + "attributes": [{ "name": "spatial", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[3,1,2]", + "inputs": [ + { + "data": [0.2134, 0.32434, 0.5644, 0.3234, 0.4545, 0.3445], + "dims": [3, 1, 2], + "type": "float32" + }, + { + "data": [0.5, 0.6], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.2, 0.1], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.034, 0.342], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [1, 1], + "dims": [1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.2897, 0.089404, 0.4652, 0.08884, 0.41025, 0.1015], + "dims": [3, 1, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "BatchNormalization non-spatial mode - NHWC", + "operator": "BatchNormalization", + "opset": { "domain": "com.ms.internal.nhwc", "version": 7 }, + "attributes": [{ "name": "spatial", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[3,2,1]", + "inputs": [ + { + "data": [0.2134, 0.32434, 0.5644, 0.3234, 0.4545, 0.3445], + "dims": [3, 2, 1], + "type": "float32" + }, + { + "data": [0.5, 0.6], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.2, 0.1], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.034, 0.342], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [1, 1], + "dims": [1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.2897, 0.089404, 0.4652, 0.08884, 0.41025, 0.1015], + "dims": [3, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 219e15eb4648f..2e8eaaba191d0 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc 
@@ -126,7 +126,7 @@ ] }, { - "name": "conv with bias addition C", + "name": "conv with bias addition C - NHWC", "operator": "Conv", "inputShapeDefinitions": "rankOnly", "opset": { "domain": "", "version": 17 }, @@ -158,6 +158,36 @@ "type": "float32" } ] + }, + { + "name": "inChannel = 3, outChannel = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [4, 3, 2, 2], + "type": "float32" + }, + { + "data": [5, 6, 7, 8], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [360, 334, 271, 323, 909, 963, 1024, 1028, 683, 655, 576, 650, 473, 508, 570, 677], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] } ] }, diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 460122b4e085c..35888e2fc3709 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -85,5 +85,34 @@ ] } ] + }, + { + "name": "Expand 5D - float32", + "operator": "Expand", + "attributes": [], + "cases": [ + { + "name": "Expand 5 - float32", + "inputs": [ + { + "data": [1], + "dims": [1, 1, 1, 1, 1], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 6], + "dims": [5], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 1, 1, 1, 1, 1], + "dims": [1, 1, 1, 1, 6], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 37aa9394c7f96..a313adef7151b 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1337,6 +1337,7 @@ //"and.jsonc", "asin.jsonc", "attention.jsonc", + "batch-norm.jsonc", "bias-add.jsonc", "bias-split-gelu.jsonc", "ceil.jsonc", diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index 0146e81c6cf8c..fb7091592c16e 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -34,17 +34,17 @@ constexpr int NumReduceDim = 3; template auto GetCKGroupNormNHWCTypeStringAndOps() { - using InDataType = typename CKDataTypeAdaptor::type; - using OutDataType = typename CKDataTypeAdaptor::type; - using AccDataType = typename CKDataTypeAdaptor::type; + using XDataType = typename CKDataTypeAdaptor::type; + using YDataType = typename CKDataTypeAdaptor::type; + using SaveMeanInvStdDataType = typename CKDataTypeAdaptor::type; using GammaDataType = float; using BetaDataType = float; using Activation = std::conditional_t; std::vector>>> ret; - for (auto&& impl : internal::GetDeviceGroupNormInstances()) { + for (auto&& impl : internal::GetDeviceGroupNormInstances()) { std::string swish_suffix = WithSwish ? 
"_Swish" : "_Pass"; auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; auto invoker = impl->MakeInvokerPointer(); @@ -69,6 +69,8 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { gamma_beta_strides, // gammaStrides gamma_beta_strides, // betaStrides in_out_strides, // yStrides + {0, 0}, // saveMeanStrides + {0, 0}, // saveInvStdStrides reduce_dims, // reduceDims params->epsilon, params->src, diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 88443478cf521..19b081881dcec 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -6,8 +6,8 @@ #ifdef USE_COMPOSABLE_KERNEL #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" #include "ck/utility/data_type.hpp" namespace onnxruntime { @@ -21,102 +21,104 @@ using F32 = float; using Swish = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceNormalization; // the interface -using ck::tensor_operation::device::DeviceNormalizationImpl; // the implementation +using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface +using ck::tensor_operation::device::DeviceNormalizationFwdImpl; // the implementation + +// See https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1fefd82ed8/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp template using device_normalization_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + 
DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; template -using device_normalization_f16_instances = std::tuple< +using device_normalization_f16_instances = // clang-format off - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, OutElementwise, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, // irregular size - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; // Use this function to get implementation -template -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { return {}; } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Swish, 5, 3>(); template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances< - F16, F32, F32, F32, F16, Pass, 5, 3>(); + F16, F32, F32, F16, F32, Pass, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Swish, 5, 3>(); template <> -std::vector>> GetDeviceGroupNormInstances< F32, F32, F32, F32, F32, Pass, 5, 3>(); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu index d1dd78e3452da..6718f29268031 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" 
-#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f16_instances{}); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 97baed34a341d..9b0ccab17b4c1 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -4,7 +4,6 @@ #ifdef USE_COMPOSABLE_KERNEL #include "contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" namespace onnxruntime { namespace contrib { @@ -12,9 +11,9 @@ namespace rocm { namespace internal { template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{}); @@ -23,9 +22,9 @@ GetDeviceGroupNormInstances() { } template <> -std::vector>> +std::vector>> GetDeviceGroupNormInstances() { - std::vector>> instances; + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, device_normalization_f32_instances{}); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 526d220d4be24..b7b9441ac997d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -77,7 +77,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { params->epsilon}; // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->stream, i, params->n, params->groups, 1, &args, sizeof(args)); + return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); }; ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); } diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index 6e0eb460d2a63..eca1221e84cb8 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -37,5 +38,32 @@ inline InlinedVector SplitString(std::string_view string_to_sp return result; } +/** + * Trim a string from start inplace. + * @param s The string to trim. + */ +inline void TrimStringFromLeft(std::string& s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); +} + +/** + * Trim a string from end inplace. + * @param s The string to trim. 
+ */ +inline void TrimStringFromRight(std::string& s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end()); +} + +/** + * Trim a string from both ends. + * @param s The string to trim. + * @return The trimmed string. + */ +inline std::string TrimString(std::string s) { + TrimStringFromRight(s); + TrimStringFromLeft(s); + return s; +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc index ea93db58339c7..4f5fa9910b5df 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -53,128 +53,200 @@ Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrRe // clang-format off constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = { 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0xbc, 0x06, 0x00, 0x00, - 0x4c, 0x02, 0x00, 0x00, 0xe0, 0x01, 0x00, 0x00, 0xe0, 0x00, 0x00, 0x00, 0x14, 0x06, 0x00, 0x00, - 0x88, 0x01, 0x00, 0x00, 0xb8, 0x05, 0x00, 0x00, 0x1c, 0x05, 0x00, 0x00, 0x18, 0x07, 0x00, 0x00, - 0xcc, 0x04, 0x00, 0x00, 0x0c, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x54, 0x05, 0x00, 0x00, - 0x3c, 0x06, 0x00, 0x00, 0xf8, 0x02, 0x00, 0x00, 0x7c, 0x02, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x38, 0x03, 0x00, 0x00, 0xec, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, + 0x4c, 0x0b, 0x00, 0x00, 0xac, 0x08, 0x00, 0x00, 0xd0, 0x0a, 0x00, 0x00, 0x10, 0x06, 0x00, 0x00, + 0xa8, 0x07, 0x00, 0x00, 0x18, 0x03, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x44, 0x07, 0x00, 0x00, 0x9c, 0x01, 0x00, 0x00, 0xf8, 0x07, 0x00, 0x00, 0x78, 0x09, 0x00, 0x00, + 0x14, 0x01, 0x00, 0x00, 0x50, 0x06, 0x00, 0x00, 0x60, 0x02, 0x00, 0x00, 0xf4, 0x08, 0x00, 0x00, + 0x8c, 0x03, 0x00, 0x00, 0x9c, 0x02, 0x00, 0x00, 0x84, 0x06, 0x00, 0x00, 0xcc, 0x03, 0x00, 0x00, + 0x60, 0x05, 0x00, 0x00, 0xb8, 0x01, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0x08, 0x04, 0x00, 0x00, + 0xe0, 0x09, 0x00, 0x00, 0x8c, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xf4, 0xff, 0xff, + 0x08, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xda, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf4, 0xff, 0xff, + 0xd8, 0xf4, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, + 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf5, 0xff, 0xff, 0xa4, 0x0a, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfc, 0xf4, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x2c, 0xf5, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x4e, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x48, 0xf5, 0xff, 0xff, 0xc8, 0x0a, 
0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x30, 0xf5, 0xff, 0xff, 0x6c, 0xf5, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x39, 0x00, 0x00, 0x9c, 0xf5, 0xff, 0xff, 0x3c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc2, 0xf5, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x94, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xc4, 0xf5, 0xff, 0xff, + 0xe8, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xb4, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xac, 0xf5, 0xff, 0xff, + 0xe8, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x10, 0xf6, 0xff, 0xff, 0xac, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x36, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xf8, 0xf5, 0xff, 0xff, 0x34, 0xf6, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, + 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, + 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x74, 0xf6, 0xff, 0xff, + 0x38, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x64, 0xf6, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x5c, 0xf6, 0xff, 0xff, + 0x98, 0xf6, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbe, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf6, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0xe4, 0xf6, 0xff, 0xff, + 0x2c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x0a, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xcc, 0xf6, 0xff, 0xff, + 0x08, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x30, 0xf7, 0xff, 0xff, 0xe0, 0x08, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x56, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0xf7, 0xff, 0xff, 0x54, 0xf7, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, + 0x78, 0xf7, 0xff, 0xff, 0x98, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9e, 0xf7, 
0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x60, 0xf7, 0xff, 0xff, 0x9c, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x4d, 0x61, - 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0x20, 0xf9, 0xff, 0xff, 0xf0, 0x06, 0x00, 0x00, + 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0xd0, 0xf7, 0xff, 0xff, 0x40, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0e, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xf9, 0xff, 0xff, 0x44, 0xf9, 0xff, 0xff, + 0xf6, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xf7, 0xff, 0xff, 0xf4, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x6c, 0xf9, 0xff, 0xff, 0xa4, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5a, 0xf9, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x54, 0xf9, 0xff, 0xff, 0x90, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, 0xb4, 0xf9, 0xff, 0xff, - 0x5c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa2, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf9, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0x1c, 0xf8, 0xff, 0xff, 0xf4, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xf8, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xf8, 0xff, 0xff, 0x40, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, + 0x68, 0xf8, 0xff, 0xff, 0xa8, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xf8, 0xff, 0xff, 0x8c, 0xf8, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, 0xc8, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, + 0x0c, 0x01, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, + 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, + 0xd8, 0xf8, 0xff, 0xff, 0xdc, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc4, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xf4, 0xf8, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x22, 0xf9, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xf4, 0xf8, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x24, 0xf9, 0xff, 0xff, + 0xe4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0xf9, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0x40, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 
+ 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf9, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x68, 0xf9, 0xff, 0xff, 0x70, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xf9, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, + 0x60, 0xf9, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x90, 0xf9, 0xff, 0xff, 0x1c, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x80, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x78, 0xf9, 0xff, 0xff, 0xb4, 0xf9, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0xd8, 0xf9, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xb4, 0x01, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0xfa, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x1c, 0xfa, 0xff, 0xff, 0xf4, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0a, 0xfa, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x04, 0xfa, 0xff, 0xff, 0x40, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, - 0x68, 0xfa, 0xff, 0xff, 0x3c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x56, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x50, 0xfa, 0xff, 0xff, 0x8c, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, - 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0xb4, 0xfa, 0xff, 0xff, - 0x00, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xfc, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd0, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, + 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x04, 0xfa, 0xff, 0xff, + 0x84, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xf0, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x20, 0xfa, 0xff, 0xff, 0xf0, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbe, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xfa, 0xff, 0xff, 0xf4, 0xfa, 0xff, 0xff, + 0x46, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xfa, 0xff, 0xff, 0x44, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, - 0x31, 0x31, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0x98, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x38, 0xfb, 0xff, 0xff, 0xd8, 
0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x26, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x20, 0xfb, 0xff, 0xff, 0x5c, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, - 0x88, 0xfb, 0xff, 0xff, 0x88, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xfb, 0xff, 0xff, 0xac, 0xfb, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x00, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd4, 0xfb, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x31, 0x00, 0x00, 0x00, 0xfc, 0xfb, 0xff, 0xff, 0x14, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xea, 0xfb, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xe4, 0xfb, 0xff, 0xff, 0x20, 0xfc, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x38, 0x01, 0x00, 0x00, 0xdc, 0x00, 0x00, 0x00, - 0xa8, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, - 0x48, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, - 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, - 0x76, 0x3a, 0x31, 0x00, 0x6c, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbc, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x90, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe4, 0xfc, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, - 0xb8, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, - 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0c, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xe0, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd6, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x3c, 0xfd, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x10, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0xfd, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, - 0x6c, 0xfd, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x40, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x94, 0xfd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, - 0x68, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 
0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbc, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x58, 0xfd, 0xff, 0xff, 0x94, 0xfd, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, - 0xb8, 0xfd, 0xff, 0xff, 0x58, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa6, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xa0, 0xfd, 0xff, 0xff, 0xdc, 0xfd, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0xff, - 0xa0, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xf2, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xec, 0xfd, 0xff, 0xff, - 0x28, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x50, 0xfe, 0xff, 0xff, 0xc0, 0x01, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x3e, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, 0x74, 0xfe, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, - 0x00, 0x00, 0x00, 0x00, 0x9c, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x92, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfe, 0xff, 0xff, - 0xc8, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xfe, 0xff, 0xff, 0x20, 0x01, 0x00, 0x00, + 0x31, 0x31, 0x00, 0x00, 0x6c, 0xfa, 0xff, 0xff, 0xc4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x88, 0xfa, 0xff, 0xff, 0x88, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xae, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x70, 0xfa, 0xff, 0xff, 0xac, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0xd0, 0xfa, 0xff, 0xff, 0x40, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xde, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xfe, 0xff, 0xff, 0x14, 0xff, 0xff, 0xff, + 0xf6, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xfa, 0xff, 0xff, 0xf4, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x3c, 0xff, 0xff, 0xff, 0xd4, 0x00, 0x00, 0x00, 0x04, 0x00, 
0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2a, 0xff, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x24, 0xff, 0xff, 0xff, 0x60, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x1c, 0xfb, 0xff, 0xff, 0xf4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xfb, 0xff, 0xff, 0x40, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, + 0x68, 0xfb, 0xff, 0xff, 0xa8, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xfb, 0xff, 0xff, 0x8c, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfb, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xa4, 0xfb, 0xff, 0xff, 0xe0, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, + 0x08, 0xfc, 0xff, 0xff, 0x08, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xf0, 0xfb, 0xff, 0xff, 0x2c, 0xfc, 0xff, 0xff, 0x04, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x18, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x48, 0xfc, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x30, 0x00, 0x00, 0x7c, 0xfc, 0xff, 0xff, 0x30, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xfc, 0xff, 0xff, 0x94, 0xfc, 0xff, 0xff, + 0x44, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xba, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfc, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0xbc, 0xfc, 0xff, 0xff, 0x4c, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xd8, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, + 0x00, 0x00, 0x00, 0x00, 0x0c, 0xfd, 0xff, 0xff, 0xcc, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x34, 0xfd, 0xff, 0xff, + 0x78, 0x01, 
0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x24, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x1c, 0xfd, 0xff, 0xff, + 0x58, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x80, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x78, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xa8, 0xfd, 0xff, 0xff, 0x68, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xce, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xfd, 0xff, 0xff, 0xcc, 0xfd, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, + 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x28, 0xfe, 0xff, 0xff, 0x84, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0xff, 0x40, 0xfe, 0xff, 0xff, 0x98, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x66, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x68, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, + 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0xa4, 0xfe, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9c, 0xfe, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x94, 0xfe, 0xff, 0xff, 0xd0, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfe, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xd0, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, - 0x88, 0xff, 0xff, 0xff, 0x88, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xff, 0xff, 0xff, 0xac, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0xdc, 0xff, 
0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x28, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x20, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x50, 0xff, 0xff, 0xff, 0xc0, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x76, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xff, 0xff, 0xff, 0x74, 0xff, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, + 0x00, 0x00, 0x00, 0x00, 0xac, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa4, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd4, 0xff, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, }; // clang-format on diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 03ad95260c0ad..c8960578f9e3d 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -101,6 +101,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function +struct BlockwiseQuantization { + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(sizeof(ElementT) == 2, "Only 16b floating point types are supported!"); + + using QuantBlocking = + std::conditional_t, + MatrixShape<1, block_size>>; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + // We pack 4 weights into one 16b element, so we can leverage cutlass tile iterators + // for async share memory loading, and minimizing bank conflict during matrix loading + using ElementWPack = ElementT; + using LayoutWPack = ColumnMajorLayout; // <- layout of packed weight, must be column major + + // Current Ampere kernel use 8b zero point, need to shrink it to 4b in the future + using ElementQOffset = uint8_t; + + // Layout of the quantization parameters (scales and zero points) + // Major on the dimension that has the most parameters per squarish weight block. + // E.g. 
for column-wise quantization, a [64, 64] block has [2, 64] parameters, + where each row has more data, so we use row major layout so that warp threads + can use fewer load instructions to load more parameters. + using LayoutQmeta = + typename std::conditional::type; + + /** + * @brief Get quantized weight tensor dimensions. + * Actual weight type is int4, we use ElementW = uint8 to avoid possible compilation + * troubles. Since the layout is column major, we are packing 2 weights in a column + * into one int8. + */ + static inline auto get_quant_weights_shape(int rows, int columns) { + return make_Position(rows / 2, columns); + } + + static inline auto get_quant_meta_shape(int rows, int columns) { + return make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + } + + /** + * @brief Prepack weight matrix to facilitate matrix loading, depending on MMA + * instruction layout. + * + * The weight matrix is int4, yet we want to leverage existing fp16/bf16 + * tile loading and MMA layout code in CUTLASS. So we group 4 int4 into 2 + * bytes, pretending it's fp16. This grouping must be done in a way to be + * easily unpacked into tiles that match the MMA instruction layout. + * For MMA instruction <16, 8, 16>, each instruction processes 2 8x8 tiles, + * vertically stacked on the K dimension. And MmaTensorOpMultiplicandTileIterator + * loads a tile. + * + * So we stack 2x2 tiles on a 3rd dimension, and reshape them in a HWC fashion: + * T0, T2 + * T1, T3 + * ==> + * T0[0, 0], T1[0, 0], T2[0, 0], T3[0, 0] + * T0[1, 0], T1[1, 0], T2[1, 0], T3[1, 0] + * T0[2, 0], T1[2, 0], T2[2, 0], T3[2, 0] + * T0[3, 0], T1[3, 0], T2[3, 0], T3[3, 0] + * ... + * T0[0, 7], T1[0, 7], T2[0, 7], T3[0, 7] + * T0[1, 7], T1[1, 7], T2[1, 7], T3[1, 7] + * T0[2, 7], T1[2, 7], T2[2, 7], T3[2, 7] + * T0[3, 7], T1[3, 7], T2[3, 7], T3[3, 7] + * + * This packs an 8x16 int8 tile into a 16x8 int8 tile, i.e. an 8x8 16b tile. + */ + static void prepack_weights( + int rows, + int columns, + const gsl::span& weights, // <- int4 weights, column major + const gsl::span& weights_prepacked // <- int4 prepacked weights tensor, same size buffer + ) { + ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 && + (rows % QuantBlocking::kRow) == 0 && + (columns % QuantBlocking::kColumn) == 0, + "Does not support odd number of rows or columns!"); + ORT_ENFORCE(weights.size() == size_t(rows * columns / 2), + "Weight tensor shape mismatch!"); + ORT_ENFORCE(weights_prepacked.size() == weights.size(), + "Prepacked Weight tensor buffer should be the same size!"); + + const MatrixRef + tensor_weight(weights, make_Position(rows / 2, columns)); + const MatrixRef + tensor_weight_prepacked(weights_prepacked, make_Position(rows, columns / 2)); + + // TODO(fuchen)!! parallelize this.
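+ // Base offsets of the four 4x8 sub-tiles (T0..T3) within each 8x16 source tile, arranged as T0, T2 (top) and T1, T3 (bottom), matching the packing layout described in the comment above.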
+ auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from an 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that the two weights in each adjacent pair + // are in different b16 registers at the same position. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } + } + + /** + * @brief We rearrange the values of the quantization scale and offset tensors + * to facilitate faster loading into tensor cores; this is only done for 16b gemm with + * (1,n) block quantization. + */ + static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1; + + static void prepack_quant_scales( + size_t rows, + size_t columns, + const gsl::span& scales, // <- quant scales, column major layout + const gsl::span& scales_prepacked // <- quant scales prepacked, same size buffer + ) { + auto meta_shape = get_quant_meta_shape(rows, columns); + ORT_ENFORCE(scales.size() == size_t(meta_shape.product()), + "Quantization scale tensor shape mismatch!"); + ORT_ENFORCE(scales_prepacked.size() == size_t(meta_shape.product()), + "Prepacked quantization scale tensor buffer should be the same size!"); + + MatrixRef tensor_scale(scales, meta_shape); + MatrixRef tensor_scale_prepacked(scales_prepacked, meta_shape); + + // Only prepacking scale and offset tensors for an often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently.
With a column major layout, each thread + needs two separate loads for an mma instruction, due to the tile fragment layout shown + above. To reduce the number of loads, we rearrange each column as below, so we can use + a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + // Potential transpose if the prepacked layout is different from the original layout + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row = 0; row < tensor_scale.shape()[0]; ++row) { + tensor_scale_prepacked.at(row, col) = tensor_scale.at(row, col); + } + } + } + } + + static void prepack_quant_offsets( + size_t rows, + size_t columns, + const gsl::span& offsets, // <- quant offsets, int4, column major layout + const gsl::span& offsets_prepacked // <- quant offsets prepacked, double size buffer + ) { + auto meta_shape = get_quant_meta_shape(rows, columns); + + ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0, + "Does not support odd number of rows or columns!"); + ORT_ENFORCE(offsets_prepacked.size() == size_t(meta_shape.product()), + "Wrong buffer size for prepacked quantization offsets!"); + ORT_ENFORCE(offsets.size() == size_t(((meta_shape[0] + 1) / 2) * meta_shape[1]), + "Quantization offset tensor shape mismatch!"); + + MatrixRef + tensor_offset(offsets, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); + MatrixRef tensor_offset_prepacked(offsets_prepacked, meta_shape); + + // Only prepacking scale and offset tensors for an often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for an mma instruction, due to the tile fragment layout shown + // above.
To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row_blk = 0; row_blk < meta_shape[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + // Potential transpose if the prepacked layout is different from the original layout + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_offset_prepacked.at(row + 0, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_offset_prepacked.at(row + 1, col) = pair01 >> 4; + } + } + } + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 48d975a7fd26d..b5784ecb56d01 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -779,6 +779,17 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + template void MlasBlockwiseQuantizedShape( @@ -790,6 +801,16 @@ MlasBlockwiseQuantizedShape( int& q_cols ); +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); void MLASCALL MlasBlockwiseQuantizedBufferSizes( diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index c1397e92d9d26..3d6251a694cfb 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -77,7 +77,6 @@ #include "orttraining/core/optimizer/bias_softmax_dropout_fusion.h" #include "orttraining/core/optimizer/bitmask_dropout_replacement.h" #include "orttraining/core/optimizer/sce_loss_grad_bias_fusion.h" -#include "orttraining/core/optimizer/memory_optimizer.h" #endif #ifdef ENABLE_TRITON #include "orttraining/core/optimizer/triton_fusion.h" @@ -354,18 +353,6 @@ InlinedVector> GenerateTransformers( // fusions might be prevented if this one removes a Q/DQ node too early. transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); -#ifdef ENABLE_TRAINING - // Put memory optimization transformer at last (which is done after most of fusions are done) by intention. - // Known issue: after memory optimization is completed, if some fusion happens, it is possible that the - // node priority got changed. This may disorder the execution order of nodes to recompute. 
- // TODO(pengwa): need to fix this issue. - const std::string enable_memory_optimizer = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, ""); - const std::string probe_level = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeLevel, "0"); - transformers.emplace_back(std::make_unique(enable_memory_optimizer, probe_level)); -#endif - } break; case TransformerLevel::Level3: { diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h index 91e21b655f8bd..cfa02c916b73f 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h @@ -20,6 +20,10 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { // @@region_begin(extended_minimal_build_required_kernels)@@ // kOnnxDomain ops + OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 10}, + OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 19}, + // OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 21}, pending CPU EP adding support OpIdentifierWithStringViews{kOnnxDomain, "Gather", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Gather", 11}, OpIdentifierWithStringViews{kOnnxDomain, "Gather", 13}, @@ -28,6 +32,10 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { OpIdentifierWithStringViews{kOnnxDomain, "Identity", 14}, OpIdentifierWithStringViews{kOnnxDomain, "Identity", 16}, OpIdentifierWithStringViews{kOnnxDomain, "Identity", 19}, + OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 10}, + OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 13}, + OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 19}, + // OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 21}, pending CPU EP adding support OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 11}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 13}, @@ -39,8 +47,10 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { #if !defined(DISABLE_CONTRIB_OPS) // kMSDomain ops + OpIdentifierWithStringViews{kMSDomain, "DequantizeLinear", 1}, OpIdentifierWithStringViews{kMSDomain, "NhwcMaxPool", 1}, OpIdentifierWithStringViews{kMSDomain, "QLinearConv", 1}, + OpIdentifierWithStringViews{kMSDomain, "QuantizeLinear", 1}, #endif // !defined(DISABLE_CONTRIB_OPS) // @@region_end(extended_minimal_build_required_kernels)@@ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index f42766267b0f9..3d2a81ce7f8cd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -87,12 +87,19 @@ std::vector WhereMoves() { MoveAll(q, ArgType::kOutput)}; return moves; } -QDQReplaceWithNew SplitReplacer() { +QDQReplaceWithNew SplitReplacer(bool has_split_as_input) { NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; - std::vector moves{ - MoveAndAppend(dq, ArgType::kInput, 0, 
ArgType::kInput), - MoveAll(q, ArgType::kOutput)}; + std::vector moves{MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput)}; + + if (has_split_as_input) { + // Move the optional split input to the new node. + moves.push_back(MoveAndAppend(target, ArgType::kInput, 1, ArgType::kInput, true)); + } + + moves.push_back(MoveAll(q, ArgType::kOutput)); + return QDQReplaceWithNew(kOnnxDomain, "Split", std::move(moves)); } @@ -247,7 +254,12 @@ MatMulReplaceWithQLinear::MatMulReplaceWithQLinear() } Status SplitReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { - return SplitReplacer().Run(graph, selected_nodes); + const auto& target_node = selected_nodes.Target(); + const auto& input_defs = target_node.InputDefs(); + + // The 'split' attribute became an optional input at opset 13. + bool has_split_as_input = target_node.SinceVersion() >= 13 && input_defs.size() == 2; + return SplitReplacer(has_split_as_input).Run(graph, selected_nodes); } Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 0e383c3031ca6..29178fe87f75c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -20,7 +20,7 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { const std::string action_name{"dropSplitQDQ"}; std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + std::unique_ptr selector = std::make_unique(true /*req_equal_quant_params*/); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Split", {}}}, std::move(selector), diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 3880288bdba2e..15b501c667046 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -253,7 +253,39 @@ void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder builder.num_input_defs = 1; // set to 1 as the first input is variadic } -void OutputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { +bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + const Node& dq_node = *dq_nodes.front(); + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + // All Q outputs should have same data type and (optionally) equal quantization parameters as the input. 
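  // Illustrative numbers for why equal quantization parameters allow the Q/DQ pair to be dropped later:
  // with scale = 0.05 and zero_point = 12 on both sides, DequantizeLinear maps q = 30 to
  // 0.05 * (30 - 12) = 0.9, and a QuantizeLinear with the same parameters maps 0.9 back to
  // round(0.9 / 0.05) + 12 = 30. Requantization reproduces the original value, so Split can be left to
  // operate directly on the quantized data.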
+ for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) { + const Node& q_node = *q_nodes[q_idx]; + + if (dt_input != q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + return false; + } + + if (req_equal_quant_params_ && + !IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { + return false; + } + } + + return true; +} + +void SplitSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.num_output_defs = 1; // set to 1 as the first output is variadic } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index be7f7e0288eda..d0d7fb2c2af17 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -115,6 +115,24 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { bool allow_16bit_; }; +// DQ node -> Split -> multiple Q nodes with equal quantization types. +// Optionally, the selector can require all input and output quantization parameters to be +// equal and constant. +class SplitNodeGroupSelector : public NodeGroupSelector { + public: + explicit SplitNodeGroupSelector(bool req_equal_quant_params = false) + : req_equal_quant_params_(req_equal_quant_params) {} + + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + + bool req_equal_quant_params_; // If true, only selects a node group if the input and output + // quantization parameters are all equal/constant, which enables the + // optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP. +}; + // DQ nodes for X, W and optionally B -> node -> Q class ConvNodeGroupSelector : public NodeGroupSelector { public: @@ -288,10 +306,11 @@ class InputVariadicSelector : public BaseSelector { void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; -// DQ -> node -> Variadic Q nodes -class OutputVariadicSelector : public BaseSelector { +// DQ -> Split -> variadic Q nodes +class SplitSelector : public BaseSelector { public: - OutputVariadicSelector() : BaseSelector(std::make_unique()) {} + SplitSelector(bool req_equal_quant_params = false) + : BaseSelector(std::make_unique(req_equal_quant_params)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 1a4d3a0c18151..544fe82a268c8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -27,6 +27,9 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops } /* static methods to return different operator's OpVersionMap */ + +// These are operators that do not change the data and therefore the input DQ and +// output Q have the same scale and zero_point. 
static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, {"Reshape", {}}, @@ -35,7 +38,6 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Transpose", {}}, {"MaxPool", {12}}, {"Resize", {}}, - {"Split", {}}, {"Squeeze", {}}, {"Unsqueeze", {}}, {"Tile", {}}}; @@ -81,7 +83,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Neg", {}}, {"DepthToSpace", {}}, {"SpaceToDepth", {}}, - {"Clip", {}}}; + {"Clip", {}}, + {"LpNormalization", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, @@ -97,6 +100,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() { {"Max", {}}, {"Min", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetSplitOpVersionsMap() { + return {{"Split", {}}}; +} static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { return {{"Conv", {}}}; } @@ -170,6 +176,13 @@ void RegisterVariadicSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterSplitSelector(Selectors& qdq_selectors) { + /* register selectors for Split op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetSplitOpVersionsMap(), + std::move(selector)); +} + void RegisterConvSelector(Selectors& qdq_selectors) { /* register selector for conv op */ std::unique_ptr selector = std::make_unique(); @@ -247,6 +260,7 @@ void SelectorManager::CreateSelectors() { RegisterUnarySelectors(qdq_selectors_); RegisterBinarySelectors(qdq_selectors_); RegisterVariadicSelectors(qdq_selectors_); + RegisterSplitSelector(qdq_selectors_); RegisterConvSelector(qdq_selectors_); RegisterConvTransposeSelector(qdq_selectors_); RegisterMatMulSelector(qdq_selectors_); diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 81b415c2e40ae..c479b685f9267 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -19,6 +19,9 @@ namespace onnx_transpose_optimization { /////// /////// /* Small utilities for editing nodes and manipulating axes/permutations */ +static constexpr bool IsOnnxDomain(std::string_view domain) { + return (domain == onnxruntime::kOnnxDomain) || (domain == onnxruntime::kOnnxDomainAlias); +} static std::vector DataInt64(api::TensorRef& tensor) { std::vector raw_data = tensor.Data(); @@ -95,21 +98,94 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); } +// Use to create a QuantizeLinear or DequantizeLinear node. Does not update output ValueInfo. Adds axis if needed. 
+static std::unique_ptr MakeQOrDQ(api::GraphRef& graph, std::string_view domain, std::string_view op_type, + std::vector inputs, + std::optional axis) { + std::unique_ptr node = graph.AddNode(op_type, inputs, /* num_outputs */ 1, domain); + // only set if provided and not the default + if (axis && axis != 1) { + node->SetAttributeInt("axis", *axis); + } + + return node; +} + +// Returns whether perm is a valid permutation (contains each value from 0 to perm.size() - 1 exactly once) +static bool IsValidPerm(const std::vector& perm) { + size_t rank = perm.size(); + int64_t rank_int = gsl::narrow_cast(rank); + std::vector used_dims(rank); + for (size_t i = 0; i < rank; ++i) { + int64_t x = perm[i]; + size_t x_size_t = gsl::narrow_cast(x); + if (x < 0 || x >= rank_int || used_dims[x_size_t]) { + return false; + } + used_dims[x_size_t] = true; + } + return true; +} + +static std::optional> GetPermAttrIfValid(const api::NodeRef& node) { + std::optional> perm = node.GetAttributeInts("perm"); + if (perm.has_value() && !IsValidPerm(*perm)) { + return std::nullopt; + } + return perm; +} + +static inline bool NormalizeAndValidateAxis(int64_t& axis, size_t rank) { + int64_t rank_int = gsl::narrow_cast(rank); + if (axis < 0) { + axis += rank_int; + } + + return axis >= 0 && axis < rank_int; +} + +/// +/// Check if an output value has a single consumer that is a node. +/// +/// Consumer node if found. +/// True if there is a single consumer node. +static bool OutputValueHasSingleConsumerNode(const api::GraphRef& graph, const api::NodeRef& node, size_t output_idx, + std::unique_ptr& single_consumer) { + auto value = node.Outputs()[output_idx]; + auto consumers = graph.GetValueConsumers(value); + + if (consumers->comprehensive && (consumers->nodes.size() == 1)) { + single_consumer = std::move(consumers->nodes[0]); + } else { + single_consumer.reset(); + } + + return single_consumer != nullptr; +} + +/// return the DQ node if value_name is produced by a DQ node +static std::unique_ptr GetDQIfProducingValue(const api::GraphRef& graph, std::string_view value_name) { + auto maybe_dq_node = graph.GetNodeProducingOutput(value_name); + + return (maybe_dq_node != nullptr && maybe_dq_node->OpType() == "DequantizeLinear") ? std::move(maybe_dq_node) + : std::unique_ptr(); +} + /// -/// Return a DequantizeLinear node if it's input is a constant initializer with known consumers. +/// Return a DequantizeLinear node if its input is a constant initializer and it has a single consumer. /// In this case the initializer can be updated in-place by UnsqueezeInput or TransposeInput. /// /// Current graph -/// Value to check if produced by a DQ node who's input is a constant initializer +/// Value to check if produced by a DQ node whose input is a constant initializer /// NodeRef for DQ node if it meets the requirements.
-static std::unique_ptr GetDQWithConstInitializerInput(const api::GraphRef& graph, - std::string_view dq_output_name) { - std::unique_ptr dq_node; - auto maybe_dq_node = graph.GetNodeProducingOutput(dq_output_name); +static std::unique_ptr GetDQWithConstInitializerInputAndSingleConsumer(const api::GraphRef& graph, + std::string_view value_name) { + std::unique_ptr result; + auto dq_node = GetDQIfProducingValue(graph, value_name); - if (maybe_dq_node && maybe_dq_node->OpType() == "DequantizeLinear") { + if (dq_node) { do { - auto dq_input = maybe_dq_node->Inputs()[0]; + auto dq_input = dq_node->Inputs()[0]; auto dq_constant = graph.GetConstant(dq_input); // input to DQ must be a constant initializer @@ -117,10 +193,9 @@ static std::unique_ptr GetDQWithConstInitializerInput(const api::G break; } - // For now keep it simple and don't support per-axis quantization as that would require updating the - // scale and zero point values in the DQ node to re-order if transposing, or reshape if unsqueezing. - // the rank of the `scale` and `zero point` inputs must match so we only need to check `scale`. - auto dq_scale = graph.GetConstant(maybe_dq_node->Inputs()[1]); + // For now keep it simple and don't support per-axis quantization as that would require updating the axis of + // the DQ node during TransposeInputImpl and UnsqueezeInput. + auto dq_scale = graph.GetConstant(dq_node->Inputs()[1]); if (!dq_scale || dq_scale->NumElements() != 1) { break; } @@ -131,41 +206,190 @@ static std::unique_ptr GetDQWithConstInitializerInput(const api::G break; } - // DQ output is only used by the node we're modifying. - auto dq_consumers = graph.GetValueConsumers(dq_output_name); - if (!dq_consumers->comprehensive || dq_consumers->nodes.size() != 1) { + std::unique_ptr consumer; + if (!OutputValueHasSingleConsumerNode(graph, *dq_node, 0, consumer)) { break; } - dq_node = std::move(maybe_dq_node); + result = std::move(dq_node); } while (false); } - return dq_node; + return result; } -// Returns whether perm is a valid permutation (contains each value from 0 to perm.size() - 1 exactly once) -static bool IsValidPerm(const std::vector& perm) { - size_t rank = perm.size(); - int64_t rank_int = gsl::narrow_cast(rank); - std::vector used_dims(rank); - for (size_t i = 0; i < rank; ++i) { - int64_t x = perm[i]; - size_t x_size_t = gsl::narrow_cast(x); - if (x < 0 || x >= rank_int || used_dims[x_size_t]) { - return false; - } - used_dims[x_size_t] = true; +/// +/// Insert a Q -> DQ pair after the node following the DQ by using scale and zp info from the preceding DQ node. +/// DQ -> next node => DQ -> next node -> Q -> DQ. +/// This is only called for Transpose and Unsqueeze nodes. +/// +/// DQ node. +/// Node following DQ node. +/// New DQ node at end of DQ -> next_node -> Q -> DQ. +/// True if insert was successful. 
+static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) { + std::unique_ptr single_consumer_node; + if (!OutputValueHasSingleConsumerNode(graph, dq_node, 0, single_consumer_node)) { + // should never happen as caller should have checked previously + return false; } + + auto& next_node = *single_consumer_node; + assert(next_node.OpType() == "Transpose" || next_node.OpType() == "Unsqueeze"); + + const auto dq_domain = dq_node.Domain(); + const auto& dq_inputs = dq_node.Inputs(); + const bool is_transpose = next_node.OpType() == "Transpose"; + + const auto scale_input = dq_inputs[1]; + const auto scale_value_info = graph.GetValueInfo(scale_input); + std::optional zp_input; + std::optional> zp_value_info; + + auto scale_shape = scale_value_info->Shape(); + if (!scale_shape && is_transpose) { + // axis potentially needs updating due to the transpose but we don't have the required info to do it. + return false; + } + + if (dq_inputs.size() > 2) { + zp_input = dq_inputs[2]; + zp_value_info = graph.GetValueInfo(zp_input.value()); + } + + // per-axis quantization if not a scalar (shape is empty for scalar). + // note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization, + // so we have to check the shape. + auto update_dq_axis = scale_shape && !scale_shape->empty(); + int64_t axis = dq_node.GetAttributeIntDefault("axis", 1); + + if (update_dq_axis && is_transpose) { + // update axis. + auto perm = GetPermAttrIfValid(next_node); + assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid + NormalizeAndValidateAxis(axis, scale_shape->size()); + axis = InvertPerm(*perm)[gsl::narrow_cast(axis)]; + } + + auto next_node_output_name = next_node.Outputs()[0]; + auto next_node_output_shape = graph.GetValueInfo(next_node_output_name)->Shape(); + + // setup Q node inputs. we don't connect it to next_node yet as we will move the output of that to the new DQ first. + std::vector inputs = {"", scale_input}; + if (zp_input) { + inputs.push_back(zp_input.value()); + } + + // Add Q + auto new_q_node = MakeQOrDQ(graph, dq_domain, "QuantizeLinear", inputs, axis); + auto q_node_outputs = new_q_node->Outputs(); + + // copy value info from the dq input for the type information, and update the shape to match next_node's output + graph.CopyValueInfo(dq_node.Inputs()[0], q_node_outputs[0]); // Q produces same type as the dq_node input + auto q_node_value_info = graph.GetValueInfo(q_node_outputs[0]); + q_node_value_info->SetShape(next_node_output_shape ? &*next_node_output_shape : nullptr); + + // update input to connect the DQ to the Q we just added. re-use scale and zp. 
+ inputs[0] = new_q_node->Outputs()[0]; + + // Add DQ + auto new_dq_node = MakeQOrDQ(graph, dq_domain, "DequantizeLinear", inputs, axis); + auto dq_node_outputs = new_dq_node->Outputs(); + + // straight copy of value info as the type and shape are the same as next_node's output + graph.CopyValueInfo(next_node_output_name, dq_node_outputs[0]); + + // move next_node output to the new DQ node in case it was a graph output, and connect next_node with the new Q node + graph.MoveOutput(next_node, 0, *new_dq_node, 0); + auto new_next_node_output_name = next_node.Outputs()[0]; + new_q_node->SetInput(0, new_next_node_output_name); + graph.CopyValueInfo(dq_node_outputs[0], new_next_node_output_name); + return true; } -static std::optional> GetPermAttrIfValid(const api::NodeRef& node) { - std::optional> perm = node.GetAttributeInts("perm"); - if (perm.has_value() && !IsValidPerm(*perm)) { - return std::nullopt; - } - return perm; +/// +/// Check if a DQ -> Q pair have matching type/scale/zero point. +/// If there's no operator between them, and they match, they are redundant and can be removed. +/// +/// True if they match. +static bool CheckQDQNodePairMatch(const api::GraphRef& graph, + const api::NodeRef& dq_node, const api::NodeRef& q_node) { + bool match = false; + + do { + if (dq_node.Domain() != q_node.Domain()) { + break; + } + + auto t1 = graph.GetValueInfo(dq_node.Inputs()[0])->DType(); + auto t2 = graph.GetValueInfo(q_node.Outputs()[0])->DType(); + + if (t1 == api::DataType::UNDEFINED || t2 == api::DataType::UNDEFINED || t1 != t2) { + break; + } + + auto dq_scale = dq_node.Inputs()[1]; + auto q_scale = q_node.Inputs()[1]; + + if (dq_scale != q_scale) { + auto dq_scale_value = graph.GetConstant(dq_scale); + auto q_scale_value = graph.GetConstant(q_scale); + if (!dq_scale_value || !q_scale_value) { + break; // non-const input + } + + if (dq_scale_value->Data() != q_scale_value->Data()) { + break; + } + } + + auto dq_zp = dq_node.Inputs().size() > 2 ? dq_node.Inputs()[2] : ""; + auto q_zp = q_node.Inputs().size() > 2 ? q_node.Inputs()[2] : ""; + + if (dq_zp != q_zp) { + std::optional> dq_scale_value; + std::optional> q_scale_value; + if (dq_zp != "") { + dq_scale_value = graph.GetConstant(dq_zp); + if (!dq_scale_value.value()) { + break; // non-const input + } + } + + if (q_zp != "") { + q_scale_value = graph.GetConstant(q_zp); + if (!q_scale_value.value()) { + break; // non-const input + } + } + + if (dq_scale_value.has_value() && q_scale_value.has_value()) { + if (dq_scale_value->get()->Data() != q_scale_value->get()->Data()) { + break; + } + } else { + // check the input with a value matches the default zp value of 0 + if (dq_scale_value.has_value()) { + auto data = dq_scale_value->get()->Data(); + if (!std::all_of(data.begin(), data.end(), [](auto value) { return value == 0; })) { + break; + } + } else { + // q_scale_value must have a value to get here + auto data = q_scale_value->get()->Data(); + if (!std::all_of(data.begin(), data.end(), [](auto value) { return value == 0; })) { + break; + } + } + } + } + + match = true; + + } while (false); + + return match; } // Adds rank to negative axes and checks that axes are unique and within [0, rank). Returns false if invalid. 
@@ -185,15 +409,6 @@ static bool NormalizeAndValidateAxes(std::vector& axes, size_t rank) { return true; } -static inline bool NormalizeAndValidateAxis(int64_t& axis, size_t rank) { - int64_t rank_int = gsl::narrow_cast(rank); - if (axis < 0) { - axis += rank_int; - } - - return axis >= 0 && axis < rank_int; -} - // Read int64 data from attribute or input, depending on whether model opset < provided opset static std::optional> ReadFromAttrOrInput(OptimizerCtx& ctx, api::NodeRef& node, std::string_view attr_name, size_t inp_index, @@ -425,7 +640,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. - dq_node = GetDQWithConstInitializerInput(ctx.graph, input); + dq_node = GetDQWithConstInitializerInputAndSingleConsumer(ctx.graph, input); if (dq_node) { // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input constant_dq_input = dq_node->Inputs()[0]; @@ -447,19 +662,6 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // to counteract its effect. If they later Unsqueeze the same input, the Squeeze nodes will simply be deleted // (see Case 2). if (consumers->nodes.size() > 0) { - // record the consumer node input as being special cased for use in Case 2 if a DQ node, and IsConstant - for (auto& consumer : consumers->nodes) { - auto& consumer_node_inputs = ctx.nodes_using_updated_shared_initializer[consumer->Id()]; - - // find input id/s for consumer - auto consumer_inputs = consumer->Inputs(); - for (size_t input_idx = 0; input_idx < consumer_inputs.size(); ++input_idx) { - if (consumer_inputs[input_idx] == value_to_modify) { - consumer_node_inputs.push_back(input_idx); - } - } - } - auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", value_to_modify, axes); api::NodeRef& squeeze = *squeeze_ptr; std::string_view sq_out = squeeze.Outputs()[0]; @@ -481,19 +683,8 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // Case 2: input is a Squeeze node with matching axes std::unique_ptr inp_node = ctx.graph.GetNodeProducingOutput(input); - // check if this is a special-cased DQ node where we put the Squeeze on input 0 of the DQ in 'Case 1' above - if (inp_node && inp_node->OpType() == "DequantizeLinear" && - std::find_if(ctx.nodes_using_updated_shared_initializer.begin(), - ctx.nodes_using_updated_shared_initializer.end(), - [&inp_node](const auto& entry) { - const auto id = entry.first; - const auto& input_idxs = entry.second; - // check Id matches and the entry was for input 0 of the DQ node - return id == inp_node->Id() && - std::find(input_idxs.begin(), input_idxs.end(), size_t(0)) != input_idxs.end(); - }) != ctx.nodes_using_updated_shared_initializer.end()) { - // set things up so we can look past the DQ node to the Squeeze that was inserted in front of the reshaped - // constant initializer that was shared with this node. 
+ // look past a DQ node for a Squeeze to cancel + if (inp_node && inp_node->OpType() == "DequantizeLinear") { dq_node = std::move(inp_node); auto dq_input = dq_node->Inputs()[0]; inp_node = ctx.graph.GetNodeProducingOutput(dq_input); @@ -558,6 +749,10 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons } node.SetInput(i, unsq_out); + + if (inp_node != nullptr && inp_node->OpType() == "DequantizeLinear") { + MakeQDQNodeUnit(ctx.graph, *inp_node); + } } static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::TensorRef& constant, @@ -585,10 +780,8 @@ static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::Ten // Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one, // or transpose an initializer. -static void TransposeInputImpl(api::GraphRef& graph, - NodeIdToInputIdxsMap* nodes_using_updated_shared_initializer, - api::NodeRef& node, size_t i, const std::vector& perm, - const std::vector& perm_inv) { +static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t i, + const std::vector& perm, const std::vector& perm_inv) { std::string_view input = node.Inputs()[i]; // Only local constants are editable @@ -602,7 +795,7 @@ static void TransposeInputImpl(api::GraphRef& graph, // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. - dq_node = GetDQWithConstInitializerInput(graph, input); + dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, input); if (dq_node) { // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input constant_dq_input = dq_node->Inputs()[0]; @@ -660,22 +853,6 @@ static void TransposeInputImpl(api::GraphRef& graph, if (consumers->nodes.size() > 0) { // Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv // to counteract the effect. These Transposes will hopefully be optimized out later. 
- - // record the consumer node's input as being special cased for use in Case 2 if a DQ node, and IsConstant - if (nodes_using_updated_shared_initializer) { - for (auto& consumer : consumers->nodes) { - auto& consumer_node_inputs = (*nodes_using_updated_shared_initializer)[consumer->Id()]; - - // find input id/s for consumer - auto consumer_inputs = consumer->Inputs(); - for (size_t input_idx = 0; input_idx < consumer_inputs.size(); ++input_idx) { - if (consumer_inputs[input_idx] == constant_to_modify) { - consumer_node_inputs.push_back(input_idx); - } - } - } - } - auto transpose_inv_ptr = MakeTranspose(graph, constant_to_modify, perm_inv); api::NodeRef& transpose_inv = *transpose_inv_ptr; std::string_view transpose_out = transpose_inv.Outputs()[0]; @@ -696,19 +873,8 @@ static void TransposeInputImpl(api::GraphRef& graph, // Case 2: input is a Transpose node std::unique_ptr inp_node = graph.GetNodeProducingOutput(input); - // check if this is a special-cased DQ node where we put the Transpose on input 0 of the DQ in 'Case 1' above - if (inp_node && inp_node->OpType() == "DequantizeLinear" && - nodes_using_updated_shared_initializer && - std::find_if(nodes_using_updated_shared_initializer->begin(), nodes_using_updated_shared_initializer->end(), - [&inp_node](const auto entry) { - const auto id = entry.first; - const auto& input_idxs = entry.second; - // id matches and the entry is for input 0 of the DQ node - return id == inp_node->Id() && - std::find(input_idxs.begin(), input_idxs.end(), size_t(0)) != input_idxs.end(); - }) != nodes_using_updated_shared_initializer->end()) { - // set things up so we can look past the DQ node to the Transpose that was inserted in front of the reshaped - // constant initializer that was shared with this node. + // Look past a DQ for the Transpose + if (inp_node && inp_node->OpType() == "DequantizeLinear") { dq_node = std::move(inp_node); auto dq_input = dq_node->Inputs()[0]; inp_node = graph.GetNodeProducingOutput(dq_input); @@ -739,12 +905,6 @@ static void TransposeInputImpl(api::GraphRef& graph, return; } - // NOTE: We expect the Transpose to cancel out when handling a special-cased DQ node that was originally - // connected to a shared constant initializer, so we don't expect to get here if dq_node is not nullptr. - // If there was a dq_node where the Transpose didn't cancel out we fall through to the next case - // so we retain the potential to cancel out for any other usages of the shared initializer. - assert(!dq_node); // assert in debug build to investigate. fall through to next case in release build to be safe. - if (!dq_node) { // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove // the other Transpose. @@ -762,6 +922,8 @@ static void TransposeInputImpl(api::GraphRef& graph, node.SetInput(i, transpose_out); return; + } else { + // fall through to regular processing if the Transpose prior to the DQ doesn't cancel out cleanly } } } @@ -788,19 +950,23 @@ static void TransposeInputImpl(api::GraphRef& graph, graph.GetValueInfo(transpose_out)->PermuteDims(perm); node.SetInput(i, transpose_out); + + if (inp_node && inp_node->OpType() == "DequantizeLinear") { + MakeQDQNodeUnit(graph, *inp_node); + } } +// this TransposeInput is used by the layout transformer to wrap a node in Transpose ops. 
+// there's no OptimizerCtx in that scenario void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm, const std::vector& perm_inv) { - // this TransposeInput is used by the layout transformer to wrap a node in Transpose ops. there's no OptimizerCtx - // in that scenario and we're not tracking special-cased DQ nodes as we only do that when pushing Transpose nodes. - TransposeInputImpl(graph, /* nodes_using_updated_shared_initializer */ nullptr, node, i, perm, perm_inv); + TransposeInputImpl(graph, node, i, perm, perm_inv); } static void TransposeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& perm, const std::vector& perm_inv) { - TransposeInputImpl(ctx.graph, &ctx.nodes_using_updated_shared_initializer, node, i, perm, perm_inv); + TransposeInputImpl(ctx.graph, node, i, perm, perm_inv); } // Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank. @@ -933,7 +1099,7 @@ static bool CanLikelyRemoveTranspose(const api::GraphRef& graph, api::NodeRef& t // return true if // - the value is a constant initializer // - the value is the output of a DQ node who's input is a constant initializer -// - UnsqueezeInput/TranposeInput can look past the DQ to update the constant initializer directly +// - UnsqueezeInput/TransposeInput can look past the DQ to update the constant initializer directly // - DQ node is currently ignored if it uses per-channel quantization // - supporting per-channel quantization requires modifying the scales and zero point data, which can be done // if/when there's a use-case to justify the development cost. @@ -942,37 +1108,21 @@ static bool CanLikelyRemoveTranspose(const api::GraphRef& graph, api::NodeRef& t // in-place update. if we push the same transpose through this node it should cancel out that Squeeze/Transpose // // in all these cases we expect pushing the transpose through to not require a runtime Transpose node -static bool IsConstant(const api::GraphRef& graph, const api::NodeRef& node, - size_t input_id, - std::string_view value_name, - const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { +static bool IsConstant(const api::GraphRef& graph, std::string_view value_name) { std::unique_ptr producer_node = graph.GetNodeProducingOutput(value_name); if (!producer_node) { - // initializer. may or may not be constant depending on whether it has a matching graph input + // initializer or graph input. + // initializer may or may not be constant depending on whether it has a matching graph input std::unique_ptr constant = graph.GetConstant(value_name); return constant != nullptr; } - auto node_id_to_check = node.Id(); - - // handle potentially looking past a DQ node + // look past a DQ node if (producer_node->OpType() == "DequantizeLinear") { - std::unique_ptr dq_node = GetDQWithConstInitializerInput(graph, value_name); + std::unique_ptr dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, value_name); if (dq_node != nullptr) { - // DQ node pointing to an initializer that has not been updated in-place yet - return true; - } - - // could also be a DQ that was connected to a shared initializer that was updated in-place. 
- // update the info on the node/input index to check and fall through - node_id_to_check = producer_node->Id(); - input_id = 0; // can only be input 0 of a DQ node - } - - auto entry = nodes_using_updated_shared_initializer.find(node_id_to_check); - if (entry != nodes_using_updated_shared_initializer.end()) { - if (std::find(entry->second.begin(), entry->second.end(), input_id) != entry->second.end()) { + // DQ node pointing to a constant initializer return true; } } @@ -982,29 +1132,59 @@ static bool IsConstant(const api::GraphRef& graph, const api::NodeRef& node, // Estimates the cost of transposing an input. Currently uses rank heuristic. Negative if transpose is removed. // Feel free to improve as needed. -static int EstimateTransposeValueCost(const api::GraphRef& graph, const api::NodeRef& node, - size_t input_id, std::string_view input, - const std::vector& perm_inv, - const HandlerMap& extended_handlers, - const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { +static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_view input, + const std::vector& perm_inv, const HandlerMap& extended_handlers) { // Case 1: Transposing constants probably costs nothing. - if (IsConstant(graph, node, input_id, input, nodes_using_updated_shared_initializer)) { + if (IsConstant(graph, input)) { return 0; } // Case 2: Transposing a transpose either cancels it or composes the permutations. std::unique_ptr producer_node = graph.GetNodeProducingOutput(input); - if (producer_node != nullptr && producer_node->IsOp("Transpose")) { - std::optional> perm2 = GetPermAttrIfValid(*producer_node); - if (perm2 != std::nullopt) { - if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *producer_node, extended_handlers)) { - return -EstimateValueRank(graph, input); - } else { - return 0; + + if (producer_node != nullptr) { + // this handles cancelling out a Transpose or Squeeze added to a shared initializer that was updated + // by TransposeInputImpl Case 1 or UnsqueezeInput Case 1. + // - if a shared initializer is not broadcast, we have -> Transpose -> DQ + // - if a shared initializer is broadcast, we have -> Transpose -> Squeeze -> DQ and need + // to look slightly further in the hopes of finding the Transpose. + // - in practice it's only necessary if the operator that we're looking to push the transpose through has + // more than 2 inputs, and at least one of them is broadcastable. When there are 2 inputs the input with + // the Transpose will have a negative weight. If we don't look past DQ -> Squeeze to find the Transpose + // on the other input the positive weight of the broadcast initializer will always be less as it's based on + // rank, so the total cost estimate will always be negative and we'll push the Transpose. + // onnx::Where may be the only operator that requires the look past Squeeze. + // + // look past a DQ as we do that in the TransposeInput/UnsqueezeInput handling. + // match onnx and contrib ops domain for Q/DQ while we have those ops in both domains. + if (producer_node->OpType() == "DequantizeLinear") { + auto dq_input_node = graph.GetNodeProducingOutput(producer_node->Inputs()[0]); + if (dq_input_node != nullptr) { + if (dq_input_node->OpType() == "Squeeze") { + auto squeeze_input_node = graph.GetNodeProducingOutput(dq_input_node->Inputs()[0]); + if (squeeze_input_node->OpType() == "Transpose") { + // we only want to set this if it is a Transpose as otherwise we're invalidating the cost given it is + // rank based and the Squeeze will change that.
+ producer_node = std::move(squeeze_input_node); + } + } else { + // DQ doesn't change the rank so we don't need to check the OpType of the DQ input + producer_node = std::move(dq_input_node); + } } } - } + if (producer_node->IsOp("Transpose")) { + std::optional> perm2 = GetPermAttrIfValid(*producer_node); + if (perm2 != std::nullopt) { + if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *producer_node, extended_handlers)) { + return -EstimateValueRank(graph, input); + } else { + return 0; + } + } + } + } // Case 3: We will likely need to add a transpose. return EstimateValueRank(graph, input); } @@ -1013,14 +1193,13 @@ static int EstimateTransposeValueCost(const api::GraphRef& graph, const api::Nod static int EstimateTransposeInputsCost(const api::GraphRef& graph, const api::NodeRef& node, const std::vector& perm_inv, const std::vector& input_indices, - const HandlerMap& extended_handlers, - const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { + const HandlerMap& extended_handlers) { auto inputs = node.Inputs(); int cost = 0; for (size_t j : input_indices) { - cost += EstimateTransposeValueCost(graph, node, j, inputs[j], perm_inv, extended_handlers, - nodes_using_updated_shared_initializer); + cost += EstimateTransposeValueCost(graph, inputs[j], perm_inv, extended_handlers); } + return cost; } @@ -1222,22 +1401,24 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con size_t rank = perm.size(); int64_t rank_int = gsl::narrow_cast(rank); - std::string_view input = node.Inputs()[i]; - auto constant = graph.GetConstant(input); + std::string_view input_name = node.Inputs()[i]; + auto constant = graph.GetConstant(input_name); if (constant != nullptr) { auto shape = constant->Shape(); if (shape.size() == 1 && (shape[0] == rank_int || shape[0] == 0)) { - Permute1DConstant(graph, node, *constant, i, input, perm); + Permute1DConstant(graph, node, *constant, i, input_name, perm); return; } } + // we don't check for a DQ input here as PermuteInput is only used for Resize (roi/scales/sizes) and Pad (pads) + // inputs that would never be quantized. 
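  // The Gather added below permutes a runtime 1D input by indexing it with `perm`, matching what
  // Permute1DConstant does for constant inputs. For example (illustrative values), perm = {2, 0, 1}
  // applied to an input holding {10, 20, 30} produces {30, 10, 20}.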
std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm); - std::vector gather_inputs{input, gather_indices_const}; + std::vector gather_inputs{input_name, gather_indices_const}; auto gather_ptr = graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; std::string_view gather_output = gather.Outputs()[0]; - graph.CopyValueInfo(input, gather_output); + graph.CopyValueInfo(input_name, gather_output); gather.SetAttributeInt("axis", 0); node.SetInput(i, gather_output); } @@ -2057,14 +2238,6 @@ static const std::unordered_map handler_ma {"Reshape", reshape_handler}, }; -constexpr bool IsOnnxDomain(std::string_view domain) { - return (domain == onnxruntime::kOnnxDomain) || (domain == onnxruntime::kOnnxDomainAlias); -} - -constexpr bool IsMSDomain(std::string_view domain) { - return domain == onnxruntime::kMSDomain; -} - static const HandlerInfo* GetHandler(api::NodeRef& node, const HandlerMap& extended_handlers) { std::string key; auto domain = node.Domain(); @@ -2095,14 +2268,12 @@ static int CalculateCost(const api::GraphRef& graph, const api::NodeRef& node, const std::unordered_set& outputs_leading_to_transpose, const HandlerInfo& info, const std::vector& input_indices, - const HandlerMap& extended_handlers, - const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { + const HandlerMap& extended_handlers) { // We require the input cost (number of transposes before the op) and the total cost to strictly decrease. // Strict decrease of the input cost ensures the optimization is stable, since the total cost decrease is just an // estimate (the transpose after the op may or may not cancel with a subsequent transpose). We don't want // repeated runs of the optimizer to have a transpose toggle between two inputs of a binary op. - int cost = EstimateTransposeInputsCost(graph, node, perm, input_indices, extended_handlers, - nodes_using_updated_shared_initializer); + int cost = EstimateTransposeInputsCost(graph, node, perm, input_indices, extended_handlers); if (cost < 0 && info.transposes_outputs) { // If the output will be transposed and won't ultimately cancel, factor in that cost. @@ -2127,19 +2298,18 @@ static int CalculateCost(const api::GraphRef& graph, const api::NodeRef& node, } // Default cost check. Returns `true` if pushing the Transpose through the node is considered to be beneficial. 
-static bool ShouldPushTranspose(const api::GraphRef& graph, const api::NodeRef& node, - const std::vector& perm, - const std::unordered_set& outputs_leading_to_transpose, - const HandlerInfo& info, - const std::vector transposable_input_indices, - const HandlerMap& extended_handlers, - const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { +static bool DefaultCostCheck(const api::GraphRef& graph, const api::NodeRef& node, + const std::vector& perm, + const std::unordered_set& outputs_leading_to_transpose, + const HandlerInfo& info, + const std::vector transposable_input_indices, + const HandlerMap& extended_handlers) { if (node.IsOp("Transpose")) { return true; } int cost = CalculateCost(graph, node, perm, outputs_leading_to_transpose, info, transposable_input_indices, - extended_handlers, nodes_using_updated_shared_initializer); + extended_handlers); return cost < 0; } @@ -2165,8 +2335,8 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef& } if (cost == CostCheckResult::kFallThrough) { - cost = ShouldPushTranspose(ctx.graph, node, perm, outputs_leading_to_transpose, *info, input_indices, - ctx.extended_handlers, ctx.nodes_using_updated_shared_initializer) + cost = DefaultCostCheck(ctx.graph, node, perm, outputs_leading_to_transpose, *info, input_indices, + ctx.extended_handlers) ? CostCheckResult::kPushTranspose : CostCheckResult::kStop; } @@ -2200,7 +2370,7 @@ std::optional MakeOptimizerContext(api::GraphRef& graph, return std::nullopt; } - OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers, {}}; + OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers}; return ctx; } @@ -2320,77 +2490,99 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { } } } - if (!have_dq) { result.graph_modified = changed; return result; } - // Run second optimization pass. - // If any transpose succeeds a DQ node, move it above the DQ node if it's not part of a QDQ node group. - // In QDQ models this helps to preserve the QDQ node group when a Transpose was pushed across a DQ into - // an existing QDQ node group. - // In all other scenarios this is beneficial as well because moving transpose above DQ node is more efficient as - // transpose node now handles less data. + // Run 'fix up' pass for QDQ node units. + // + // Repair broken QDQ node unit from Transpose being blocked on Op inside a QDQ node unit. + // DQ -> Transpose -> Op -> Q => + // DQ -> Transpose -> Q -> DQ -> Op -> Q + // + // Create QDQ node unit for Transpose after DQ that provides graph output. + // DQ -> Transpose -> graph output => + // DQ -> Transpose -> Q -> DQ -> graph output + // + // Remove empty DQ -> Q pair from moving a Transpose downstream or a Transpose being cancelled out. 
+ // DQ -> Q -> consumer node => + // consumer node + auto graph_nodes = ctx.graph.Nodes(); for (size_t i = 1; i < graph_nodes.size(); i++) { - const auto& node = *graph_nodes[i]; + auto& node = *graph_nodes[i]; if (!can_modify_node(node)) { continue; } - if (node.OpType() == "Transpose") { - auto& transpose_node = *graph_nodes[i]; - auto dq_node = ctx.graph.GetNodeProducingOutput(transpose_node.Inputs()[0]); - if (!dq_node || dq_node->OpType() != "DequantizeLinear") { + for (size_t i_idx = 0, i_end = node.Inputs().size(); i_idx < i_end; ++i_idx) { + // any change requires a DQ as the input to the current node + auto input_node = ctx.graph.GetNodeProducingOutput(node.Inputs()[i_idx]); + if (!input_node || input_node->OpType() != "DequantizeLinear") { continue; } - // Check if Transpose node is the only consumer of dq node - auto consumers_of_dq_node = ctx.graph.GetValueConsumers(dq_node->Outputs()[0]); - if (!consumers_of_dq_node->comprehensive || consumers_of_dq_node->nodes.size() > 1) { - continue; - } + auto& dq_node = *input_node; + std::unique_ptr single_consumer_node; + + // remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp. + if (node.OpType() == "QuantizeLinear") { + // we don't need to check scale and zp inputs, and we may remove nodes invalidating `node` if we + // continue with the loop of inputs so set i_end to bail + i_end = 1; + + auto& q_node = node; + if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) && + OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) && + CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) { + // connect Q consumer to DQ input + for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) { + if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) { + single_consumer_node->SetInput(j_idx, dq_node.Inputs()[0]); + // break; in theory the Q might be providing multiple inputs. + } + } - auto consumers_of_transpose_node = ctx.graph.GetValueConsumers(transpose_node.Outputs()[0]); - bool is_part_of_qdq_group = std::find_if(consumers_of_transpose_node->nodes.cbegin(), - consumers_of_transpose_node->nodes.cend(), - [](const std::unique_ptr& node) { - return node->OpType() == "QuantizeLinear"; - }) != consumers_of_transpose_node->nodes.cend(); - if (is_part_of_qdq_group) { - continue; - } + // disconnect other nodes and remove + dq_node.SetInput(0, ""); + q_node.SetInput(0, ""); + ctx.graph.RemoveNode(dq_node); + ctx.graph.RemoveNode(q_node); - // Update Dequantize Node and move the transpose above it - auto perm = GetPermAttrIfValid(transpose_node); - if (!perm.has_value()) { - continue; + changed = true; + continue; + } } - // we're moving the Transpose to before the DQ, so we need to use the inverse permutations to update the axis - // attribute correctly when doing per-axis dequantization - std::string_view dq_domain = dq_node->Domain(); - std::vector perm_inv = InvertPerm(*perm); - - if (IsOnnxDomain(dq_domain) && !HandleQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node, ctx.opset)) { - continue; - } + // DQ -> Transpose => DQ -> Transpose -> Q -> DQ if needed + if (node.OpType() == "Transpose") { + auto& transpose_node = node; - // NOTE: this bleeds ORT specific logic into the base optimizer, however we justify that for now because we expect - // the types that the ORT DQ provides to be added to the ONNX spec, at which point this special case can go away. 
- if (IsMSDomain(dq_domain) && !TransposeQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node)) { - continue; - } + // GetValueConsumers sets `comprehensive` to false for graph outputs and implicit inputs. + // we know Transpose doesn't have implicit inputs so if nodes are empty it can only be a graph output. + auto transpose_output = transpose_node.Outputs()[0]; + auto consumers = ctx.graph.GetValueConsumers(transpose_output); + if (consumers->nodes.empty()) { + // DQ -> Transpose -> graph output + } else { + if (consumers->nodes.size() > 1) { + // unexpected to have DQ -> Transpose -> multiple consumers + continue; + } - TransposeFirstInput(ctx, *dq_node, *perm); + if (consumers->nodes[0]->OpType() == "QuantizeLinear") { + // already in QDQ node unit + continue; + } + } - // remove existing transpose node - transpose_node.SetInput(0, ""); - ctx.graph.MoveOutput(transpose_node, 0, *dq_node, 0); - ctx.graph.RemoveNode(transpose_node); - changed = true; + // Add Q -> DQ after the DQ -> Transpose + if (MakeQDQNodeUnit(ctx.graph, dq_node)) { + changed = true; + } + } } } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index cc1552704c187..6d1f1f8535ba4 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -51,32 +51,6 @@ struct OptimizerCtx { // Handlers for ops that are not in the ONNX opset, or for ONNX ops where special handling is required. // If a handler is not found in this map, the default handlers will be used. const HandlerMap& extended_handlers; - - // When we update a shared constant initializer as part of pushing a transpose through a node we update the - // initializer in-place and insert Squeeze (in UnsqueezeInput if the initializer is broadcast) or - // Transpose (in TransposeInput) nodes between the updated initializer and the other usages. - // This map contains the set of nodes that had a Squeeze or Transpose added between them and the initializer. - // The entry contains the node id (key) and original input index/es (value) that were connected to the initializer - // prior to the insertion of the Squeeze/Transpose. - // - // Assuming we also transpose the other usages of the initializer in the same way (which would be expected) the - // Squeeze and Transpose nodes would be cancelled out, and the other usages will end up using the original - // initializer that was updated in-place. - // - // We use this information in two ways. - // - // 1. In the IsConstant calculation that determines the cost of pushing a transpose through a node. - // - as we expect the transpose to be making the same modification to all shared usages of the initializer we - // expect the Squeeze/Transpose nodes to be cancelled out, resulting in no runtime cost to push the transpose - // through that input. - // - // 2. To enable and track a special case in a QDQ format model where there is the added complexity of a DQ node - // between the initializer and each usage. 
- // - we look past a DQ node in UnsqueezeInput and TransposeInput to determine if there is a constant initializer - // that can be updated in-place as the DQ node is not sensitive to any rank or layout changes - // - NOTE we currently ignore DQ nodes with per-channel quantization as they are sensitive to changes - // - we also look past DQ nodes when processing the other usages in order to cancel out the Squeeze/Transpose - NodeIdToInputIdxsMap nodes_using_updated_shared_initializer; }; /// diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index fb338be1c7f5a..c45aaef0cf02f 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -442,6 +442,13 @@ class GraphRef { return !unused; } + /// + /// Is the value a graph output. + /// + /// Value name. + /// True if output of the Graph. + virtual bool IsGraphOutput(std::string_view name) const = 0; + virtual ~GraphRef(){}; }; diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index 2fcb88cb0b9ba..d9f08ffe1171e 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -107,10 +107,17 @@ class ApiGraph final : public api::GraphRef { onnxruntime::Graph& graph_; AllocatorPtr cpu_allocator_; const char* new_node_ep_; + std::unordered_set graph_outputs_; // graph_.GetOutputs() names for efficient lookup public: explicit ApiGraph(onnxruntime::Graph& graph, AllocatorPtr cpu_allocator, const char* new_node_ep) - : graph_(graph), cpu_allocator_(std::move(cpu_allocator)), new_node_ep_(new_node_ep) {} + : graph_(graph), cpu_allocator_(std::move(cpu_allocator)), new_node_ep_(new_node_ep) { + const auto& graph_outputs = graph_.GetOutputs(); + graph_outputs_.reserve(graph_outputs.size()); + for (const auto* output : graph_outputs) { + graph_outputs_.insert(output->Name()); + } + } onnxruntime::Graph& Graph() { return graph_; @@ -138,6 +145,7 @@ class ApiGraph final : public api::GraphRef { void MoveOutput(api::NodeRef& src_node, size_t src_idx, api::NodeRef& dst_node, size_t dst_idx) override; void CopyValueInfo(std::string_view src_name, std::string_view dst_name) override; bool HasValueConsumers(std::string_view name) const override; + bool IsGraphOutput(std::string_view name) const override; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiGraph); @@ -447,6 +455,10 @@ std::vector> ApiGraph::Nodes() const { return nodes; } +bool ApiGraph::IsGraphOutput(std::string_view name) const { + return graph_outputs_.find(name) != graph_outputs_.end(); +} + std::unique_ptr ApiGraph::GetConstant(std::string_view name) const { const auto* tensor = graph_.GetConstantInitializer(std::string(name), /*check_outer_scope*/ true); if (tensor == nullptr) { @@ -494,11 +506,8 @@ std::unique_ptr ApiGraph::GetValueConsumers(std::string_vie } } - const auto& graph_outputs = graph_.GetOutputs(); - for (const auto* output : graph_outputs) { - if (output->Name() == name) { - consumers->comprehensive = false; - } + if (IsGraphOutput(name)) { + consumers->comprehensive = false; } return consumers; @@ -510,14 +519,7 @@ bool ApiGraph::HasValueConsumers(std::string_view name) const { return true; } - const auto& graph_outputs = graph_.GetOutputs(); - for (const auto* 
output : graph_outputs) { - if (output->Name() == name) { - return true; - } - } - - return false; + return IsGraphOutput(name); } std::unique_ptr ApiGraph::GetNodeProducingOutput(std::string_view name) const { @@ -704,10 +706,6 @@ static std::optional GetLayoutTransformationPotentiallyAddedOpSinceVersion( // Based on the opset version imported for this model, returns the since version for the node. static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view domain, const std::unordered_map& domain_to_version_map) { - // TODO do we need this check? we will also check kLayoutTransformationPotentiallyAddedOps - ORT_ENFORCE(domain == kOnnxDomain, "Transpose optimizer is expected to add only onnx domain ops. Domain: ", - domain, " provided for op: ", op_type); - const auto opset_import_iter = domain_to_version_map.find(std::string(domain)); ORT_ENFORCE(opset_import_iter != domain_to_version_map.end(), domain, " domain not found in opset imports."); diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc new file mode 100644 index 0000000000000..c454a2a779f6e --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/builders/impl/base_op_builder.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" + +#ifdef __APPLE__ +#include "core/providers/coreml/builders/model_builder.h" +#endif +#include "core/providers/coreml/builders/op_builder_factory.h" + +namespace onnxruntime { +namespace coreml { + +class SoftmaxOpBuilder : public BaseOpBuilder { + // Add operator related +#ifdef __APPLE__ + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; +#endif + + // Operator support related + private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; + +// Add operator related + +#ifdef __APPLE__ + +Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + std::unique_ptr layer = CreateNNLayer(model_builder, node); + const auto& input_name = node.InputDefs()[0]->Name(); + const auto& output_name = node.OutputDefs()[0]->Name(); + + std::vector data_shape; + ORT_RETURN_IF_NOT(GetStaticShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); + + NodeAttrHelper helper(node); + int32_t axis_default_value = (node.SinceVersion() < 13) ? 1 : -1; + const auto axis = helper.Get("axis", axis_default_value); + const auto axis_nonnegative = HandleNegativeAxis(axis, data_shape.size()); + + if (node.SinceVersion() >= 13 || (data_shape.size() == 2)) { + auto* coreml_softmaxnd = layer->mutable_softmaxnd(); + coreml_softmaxnd->set_axis(axis); + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(layer)); + } else { + // note: if opsets < 13, onnx Softmax coerces the input shape to be 2D based on axis. + // we need to manually reshape to 2D and apply SoftmaxND to axis -1 to achieve equivalent results for CoreML. 
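Note on the opset < 13 path handled below: ONNX Softmax in those opsets is defined over a 2D view of the input, where the row count is the product of the dimensions before 'axis' and the column count is the product of the remaining dimensions. A minimal standalone sketch of that shape arithmetic (the shape and axis values are illustrative only, not taken from the change):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// For opset < 13, ONNX Softmax flattens the input to 2D around 'axis':
// rows = product of dims before axis, cols = product of dims from axis onward.
int main() {
  const std::vector<int64_t> shape{2, 3, 4, 5};  // hypothetical input shape
  const size_t axis = 2;                         // hypothetical, already non-negative axis

  const int64_t rows = std::accumulate(shape.begin(), shape.begin() + axis,
                                       int64_t{1}, std::multiplies<int64_t>());
  const int64_t cols = std::accumulate(shape.begin() + axis, shape.end(),
                                       int64_t{1}, std::multiplies<int64_t>());

  // Softmax runs along the last axis of the [rows, cols] view, then the result
  // is reshaped back to the original shape.
  std::cout << "reshape to [" << rows << ", " << cols << "]\n";  // prints: reshape to [6, 20]
  return 0;
}
```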
+ TensorShape input_shape(data_shape); + const auto size_to_dimension = input_shape.SizeToDimension(axis_nonnegative); + const auto size_from_dimension = input_shape.SizeFromDimension(axis_nonnegative); + + TensorShapeVector target_shape; + target_shape.push_back(size_to_dimension); + target_shape.push_back(size_from_dimension); + + const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); + { // Add reshape layer + const auto softmax_reshape1_layer_name = + model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape1")); + auto reshape_layer = CreateNNLayer(softmax_reshape1_layer_name); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; + *reshape_layer->mutable_input()->Add() = input_name; + *reshape_layer->mutable_output()->Add() = reshape1_output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } + const auto softmax_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "softmax_output")); + { + auto* coreml_softmaxnd = layer->mutable_softmaxnd(); + coreml_softmaxnd->set_axis(-1); + *layer->mutable_input()->Add() = reshape1_output_name; + *layer->mutable_output()->Add() = softmax_output_name; + model_builder.AddLayer(std::move(layer)); + } + { + // Add reshape back layer + const auto softmax_reshape2_layer_name = + model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape2")); + auto reshape_layer = CreateNNLayer(softmax_reshape2_layer_name); + *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; + *reshape_layer->mutable_input()->Add() = softmax_output_name; + *reshape_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(reshape_layer)); + } + } + + return Status::OK(); +} + +#endif + +// Operator support related + +bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetStaticShape(*input_defs[0], input_shape, logger)) + return false; + + const TensorShape shape(input_shape); + if (shape.Size() == 0) { + LOGS(logger, VERBOSE) << "Empty input data is not supported."; + return false; + } + + return true; +} + +void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc new file mode 100644 index 0000000000000..815f68128ffaf --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
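The new Split builder below has to materialize explicit split sizes when 'num_outputs' does not divide the axis length evenly. A minimal standalone sketch of the size computation it uses (the dimension and output count here are illustrative only):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the chunking used below: every output gets ceil(dim / num_outputs)
// elements except the last, which takes the remainder when the split is uneven.
std::vector<int64_t> SplitSizes(int64_t split_dim_size, int64_t num_outputs) {
  const int64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs;
  const int64_t remainder = split_dim_size % chunk_size;
  std::vector<int64_t> sizes(static_cast<size_t>(num_outputs), chunk_size);
  if (remainder != 0) {
    sizes.back() = remainder;  // uneven split: last output is smaller
  }
  return sizes;
}

int main() {
  for (int64_t s : SplitSizes(/*split_dim_size*/ 7, /*num_outputs*/ 3)) {
    std::cout << s << ' ';  // prints: 3 3 1
  }
  std::cout << '\n';
  return 0;
}
```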
+ +#include "core/providers/coreml/builders/impl/base_op_builder.h" + +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" + +#if defined(__APPLE__) +#include "core/providers/coreml/builders/model_builder.h" +#endif + +namespace onnxruntime { +namespace coreml { + +class SplitOpBuilder : public BaseOpBuilder { + // Add operator related +#ifdef __APPLE__ + private: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; +#endif + + // Operator support related + private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + + // Split opset 13- uses "split" as attribute. Currently it's not supported. + int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } +}; + +// Add operator related + +#ifdef __APPLE__ + +void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + const auto& input_defs = node.InputDefs(); + + if (input_defs.size() > 1 && input_defs[1]->Exists()) { // optional second input "split" + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + } +} + +Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector data_shape; + ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); + + NodeAttrHelper helper(node); + const auto axis = helper.Get("axis", 0); + + // attribute introduced since opset 18 + uint64_t num_outputs; + + std::unique_ptr layer = CreateNNLayer(model_builder, node); + auto* coreml_splitnd = layer->mutable_splitnd(); + coreml_splitnd->set_axis(axis); + + if (input_defs.size() > 1) { + // if "split" is explicitly provided as an input + const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + Initializer unpacked_tensor(split_tensor); + auto split_span = unpacked_tensor.DataAsSpan(); + auto split_sizes = split_span.size(); + num_outputs = narrow(split_sizes); + for (size_t i = 0; i < split_sizes; i++) { + coreml_splitnd->add_splitsizes(split_span[i]); + } + } else if (node.SinceVersion() < 18) { + num_outputs = narrow(node.OutputDefs().size()); + coreml_splitnd->set_numsplits(num_outputs); + } else { + // note: for opset 18+ 'num_outputs' is a required attribute + num_outputs = narrow(helper.GetInt("num_outputs").value()); + // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists + auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; + uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); + uint64_t remainder = split_dim_size % chunk_size; + if (remainder) { + // uneven + auto split_sizes = InlinedVector(num_outputs, chunk_size); + split_sizes.back() = remainder; + for (size_t i = 0; i < split_sizes.size(); i++) { + coreml_splitnd->add_splitsizes(split_sizes[i]); + } + } else { + // even + coreml_splitnd->set_numsplits(num_outputs); + } + } + + *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); + // 
variadic number of outputs. Calculated based on the length of the given splitSizes if provided. + // Otherwise, uses attribute value 'num_outputs'. + for (uint64_t i = 0; i < num_outputs; i++) { + *layer->mutable_output()->Add() = node.OutputDefs()[i]->Name(); + } + model_builder.AddLayer(std::move(layer)); + + return Status::OK(); +} + +#endif + +// Operator support related + +bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); + + NodeAttrHelper helper(node); + const auto axis = helper.Get("axis", 0); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())]; + if (input_defs.size() > 1 && input_defs[1]->Exists()) { + if (!CheckIsConstantInitializer(*input_defs[1], input_params.graph_viewer, logger, "'split'")) { + return false; + } + const auto split_shape = *input_defs[1]->Shape(); + if (split_shape.dim_size() < 2) { + LOGS(logger, VERBOSE) << "CoreML SplitND requires to produce at least 2 outputs."; + return false; + } + const auto& splits_tensor = *initializers.at(input_defs[1]->Name()); + Initializer unpacked_tensor(splits_tensor); + auto splits_span = unpacked_tensor.DataAsSpan(); + int sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), 0); + if (sum_of_splits != split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Mismatch between the sum of 'split'. Expected: " + << split_dims_at_axis + << "Actual: " + << sum_of_splits; + return false; + } + auto it = std::find(splits_span.begin(), splits_span.end(), 0); + if (it != splits_span.end()) { + LOGS(logger, VERBOSE) << "Invalid value in 'splits' input."; + return false; + } + if (split_dims_at_axis == -1) { + LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic."; + return false; + } + } else { + if (node.SinceVersion() >= 18) { + const auto num_outputs = helper.GetInt("num_outputs"); + if (!num_outputs.has_value()) { + LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; + return false; + } + if (num_outputs.value() < 2) { + LOGS(logger, VERBOSE) << "Invalid num_outputs. The value cannot be lower than 2.\n" + << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value(); + return false; + } + if (num_outputs.value() != static_cast(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n." + << "The value should be smaller or equal to the size of dimension being split. 
num_outputs: " + << num_outputs.value(); + return false; + } + } + } + return true; +} + +void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index c1b09cec8a30a..2c06659852134 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -122,6 +122,14 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSliceOpBuilder("Slice", op_registrations); } + { // Softmax + CreateSoftmaxOpBuilder("Softmax", op_registrations); + } + + { // Split + CreateSplitOpBuilder("Split", op_registrations); + } + return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index b2c8dc765d33d..d72420bcfff88 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -36,6 +36,8 @@ void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc index f36b75c508da0..eb245a4c9ba0c 100644 --- a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc +++ b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.cc @@ -141,14 +141,11 @@ struct TfIdfVectorizer::Impl { Impl(const Impl&) = delete; Impl& operator=(const Impl&) = delete; - void IncrementCount(size_t ngram_id, size_t row_num, - std::vector& frequencies) const { + inline size_t OutputIdToIncrement(size_t ngram_id) const { assert(ngram_id != 0); --ngram_id; assert(ngram_id < ngram_indexes_.size()); - size_t output_idx = row_num * output_size_ + SafeInt(ngram_indexes_[ngram_id]); - assert(output_idx < frequencies.size()); - ++frequencies[output_idx]; + return SafeInt(ngram_indexes_[ngram_id]); } }; @@ -252,77 +249,17 @@ TfIdfVectorizer::TfIdfVectorizer(const OpKernelInfo& info) : OpKernel(info), imp TfIdfVectorizer::~TfIdfVectorizer() = default; -void TfIdfVectorizer::OutputResult(OpKernelContext* ctx, size_t B, const std::vector& frequences) const { - const Impl& impl = *impl_; - std::vector output_dims; - if (B == 0) { - output_dims.push_back(impl.output_size_); - B = 1; // For use in the loops below - } else { - output_dims.push_back(B); - 
output_dims.push_back(impl.output_size_); - } - - const auto row_size = impl.output_size_; - - TensorShape output_shape(output_dims); - assert(frequences.size() == static_cast(output_shape.Size())); - - auto Y = ctx->Output(0, output_shape); - auto output_data = Y->MutableData(); - const auto& w = impl.weights_; - switch (impl.weighting_criteria_) { - case kTF: { - for (auto f : frequences) { - *output_data++ = static_cast(f); - } - } break; - case kIDF: { - if (!w.empty()) { - const auto* freqs = frequences.data(); - for (size_t batch = 0; batch < B; ++batch) { - for (size_t i = 0; i < row_size; ++i) { - *output_data++ = (*freqs++ > 0) ? w[i] : 0; - } - } - } else { - for (auto f : frequences) { - *output_data++ = (f > 0) ? 1.0f : 0; - } - } - } break; - case kTFIDF: { - if (!w.empty()) { - const auto* freqs = frequences.data(); - for (size_t batch = 0; batch < B; ++batch) { - for (size_t i = 0; i < row_size; ++i) { - *output_data++ = *freqs++ * w[i]; - } - } - } else { - for (auto f : frequences) { - *output_data++ = static_cast(f); - } - } - } break; - case kNone: // fall-through - default: - assert(false); - } -} - -void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_t row_size, - std::vector& frequencies) const { - auto X = ctx->Input(0); - const auto elem_size = X->DataType()->Size(); - - const void* const row_begin = AdvanceElementPtr(X->DataRaw(), row_num * row_size, elem_size); +void TfIdfVectorizer::ComputeImpl(const void* x_data_raw, size_t elem_size, ptrdiff_t row_num, size_t row_size, + bool is_input_string, gsl::span output_data, + std::function&)>& fn_weight) const { + const void* const row_begin = AdvanceElementPtr(x_data_raw, row_num * row_size, elem_size); const void* const row_end = AdvanceElementPtr(row_begin, row_size, elem_size); const auto& impl = *impl_; const auto max_gram_length = impl.max_gram_length_; const auto max_skip_distance = impl.max_skip_count_ + 1; // Convert to distance auto start_ngram_size = impl.min_gram_length_; + size_t output_idx; for (auto skip_distance = 1; skip_distance <= max_skip_distance; ++skip_distance) { auto ngram_start = row_begin; @@ -336,7 +273,7 @@ void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_ } auto ngram_item = ngram_start; - if (X->IsDataTypeString()) { + if (is_input_string) { const std::string* str_item = reinterpret_cast(ngram_item); const StrMap* str_map = &impl.str_map_; for (auto ngram_size = 1; @@ -349,7 +286,8 @@ void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_ break; } if (ngram_size >= start_ngram_size && hit->second->id_ != 0) { - impl.IncrementCount(hit->second->id_, row_num, frequencies); + output_idx = impl.OutputIdToIncrement(hit->second->id_); + fn_weight(output_idx, output_data); } str_map = &hit->second->leafs_; } @@ -360,13 +298,14 @@ void TfIdfVectorizer::ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_ ngram_size <= max_gram_length && ngram_item < ngram_row_end; ++ngram_size, ngram_item = AdvanceElementPtr(ngram_item, skip_distance, elem_size)) { - int64_t val = (X->IsDataType()) ? int64_t{*reinterpret_cast(ngram_item)} : *reinterpret_cast(ngram_item); + int64_t val = (elem_size == 4) ? 
int64_t{*reinterpret_cast(ngram_item)} : *reinterpret_cast(ngram_item); auto hit = int_map->find(val); if (hit == int_map->end()) { break; } if (ngram_size >= start_ngram_size && hit->second->id_ != 0) { - impl.IncrementCount(hit->second->id_, row_num, frequencies); + output_idx = impl.OutputIdToIncrement(hit->second->id_); + fn_weight(output_idx, output_data); } int_map = &hit->second->leafs_; } @@ -412,31 +351,76 @@ Status TfIdfVectorizer::Compute(OpKernelContext* ctx) const { } assert((num_rows * C) == total_items); - // Frequency holder allocate [B..output_size_] - // and init all to zero - std::vector frequencies; - frequencies.resize(num_rows * impl_->output_size_, 0); + const Impl& impl = *impl_; + TensorShapeVector output_dims; + if (B == 0) { + output_dims.push_back(impl.output_size_); + B = 1; // For use in the loops below + } else { + output_dims.push_back(B); + output_dims.push_back(impl.output_size_); + } + TensorShape output_shape(output_dims); + + auto Y = ctx->Output(0, output_shape); + auto output_data = Y->MutableData(); + const bool is_input_string = X->IsDataTypeString(); if (total_items == 0 || - (X->IsDataTypeString() && impl_->str_map_.empty()) || + (is_input_string && impl_->str_map_.empty()) || ((X->IsDataType() || X->IsDataType()) && impl_->int64_map_.empty())) { // TfidfVectorizer may receive an empty input when it follows a Tokenizer // (for example for a string containing only stopwords). // TfidfVectorizer returns a zero tensor of shape // {b_dim, output_size} when b_dim is the number of received observations // and output_size the is the maximum value in ngram_indexes attribute plus 1. - OutputResult(ctx, B, frequencies); + memset(output_data, 0, static_cast(output_shape.Size() * sizeof(float))); return Status::OK(); } - std::function fn = [this, ctx, C, &frequencies](ptrdiff_t row_num) { - ComputeImpl(ctx, row_num, C, frequencies); - }; + auto x_data_raw = ctx->Input(0)->DataRaw(); + const auto elem_size = X->DataType()->Size(); + int32_t num_batches = std::min(concurrency::ThreadPool::DegreeOfParallelism(ctx->GetOperatorThreadPool()) * 2, num_rows); - concurrency::ThreadPool::TryBatchParallelFor(ctx->GetOperatorThreadPool(), num_rows, std::move(fn), 0); + const auto& w = impl.weights_; + std::function&)> fn_weight; - OutputResult(ctx, B, frequencies); + switch (impl.weighting_criteria_) { + case kTF: + fn_weight = [](size_t i, gsl::span& out) { out[i] += 1.0f; }; + break; + case kIDF: + if (!w.empty()) { + fn_weight = [&w](size_t i, gsl::span& out) { out[i] = w[i]; }; + } else { + fn_weight = [](size_t i, gsl::span& out) { out[i] = 1.0f; }; + } + break; + case kTFIDF: + if (!w.empty()) { + fn_weight = [&w](size_t i, gsl::span& out) { out[i] += w[i]; }; + } else { + fn_weight = [](size_t i, gsl::span& out) { out[i] += 1.0f; }; + } + break; + case kNone: // fall-through + default: + assert(false); + } + + std::function fn = [this, C, output_data, x_data_raw, elem_size, + is_input_string, num_batches, num_rows, &fn_weight](ptrdiff_t batch_num) { + // Frequency holder allocate [B..output_size_] and init all to zero. 
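For context on the TfIdfVectorizer rework above: instead of accumulating integer frequencies and applying the weighting in a separate output pass, the kernel now selects one weighting functor per Compute call and applies it directly for each n-gram hit. A minimal standalone sketch of that dispatch pattern, with a hypothetical enum and hypothetical weight values:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

enum class Weighting { TF, IDF, TFIDF };  // hypothetical stand-in for the kernel's criteria

int main() {
  const std::vector<float> weights{0.5f, 2.0f, 1.5f};  // hypothetical per-index IDF weights
  std::vector<float> out(3, 0.0f);

  // Select the per-hit update once, outside the per-row loop.
  const Weighting criteria = Weighting::TFIDF;
  std::function<void(size_t, std::vector<float>&)> apply;
  switch (criteria) {
    case Weighting::TF:
      apply = [](size_t i, std::vector<float>& o) { o[i] += 1.0f; };
      break;
    case Weighting::IDF:
      apply = [&weights](size_t i, std::vector<float>& o) { o[i] = weights[i]; };
      break;
    case Weighting::TFIDF:
      apply = [&weights](size_t i, std::vector<float>& o) { o[i] += weights[i]; };
      break;
  }

  apply(1, out);  // two hits on the same output index
  apply(1, out);
  std::cout << out[1] << '\n';  // prints 4 for TFIDF with weight 2.0
  return 0;
}
```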
+ auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_batches, static_cast(num_rows)); + std::vector frequencies(this->impl_->output_size_); + for (auto row_num = work.start; row_num < work.end; ++row_num) { + auto out = gsl::span(output_data + row_num * this->impl_->output_size_, this->impl_->output_size_); + std::fill(out.begin(), out.end(), 0.0f); + ComputeImpl(x_data_raw, elem_size, row_num, C, is_input_string, out, fn_weight); + } + }; + concurrency::ThreadPool::TrySimpleParallelFor(ctx->GetOperatorThreadPool(), num_batches, std::move(fn)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h index 45db40d893231..14488d91c23e9 100644 --- a/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h +++ b/onnxruntime/core/providers/cpu/nn/tfidfvectorizer.h @@ -19,11 +19,8 @@ class TfIdfVectorizer final : public OpKernel { Status Compute(OpKernelContext* ctx) const override; private: - void ComputeImpl(OpKernelContext* ctx, ptrdiff_t row_num, size_t row_size, - std::vector& frequencies) const; - - // Apply weighing criteria and output - void OutputResult(OpKernelContext* ctx, size_t b_dim, const std::vector& frequences) const; + void ComputeImpl(const void* x_data_raw, size_t elem_size, ptrdiff_t row_num, size_t row_size, bool is_input_string, + gsl::span output_data, std::function&)>& fn_weight) const; struct Impl; std::unique_ptr impl_; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 798244d7cb75b..68ceafb1d4bf6 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -353,6 +353,15 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -636,6 +645,15 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/batch_norm.cc b/onnxruntime/core/providers/js/operators/batch_norm.cc new file mode 100644 index 
0000000000000..e18ad835792f7 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/batch_norm.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "batch_norm.h" + +namespace onnxruntime { +namespace js { + +#define REGISTER_BATCHNORM_KERNEL(OP_TYPE, DOMAIN, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, DOMAIN, 7, 8, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), KERNEL_CLASS); \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, DOMAIN, 9, 13, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), KERNEL_CLASS); \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX(OP_TYPE, DOMAIN, 14, 14, kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("U", JsepSupportedFloatTypes()), \ + KERNEL_CLASS); \ + ONNX_OPERATOR_KERNEL_EX(OP_TYPE, DOMAIN, 15, kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("T1", JsepSupportedFloatTypes()) \ + .TypeConstraint("T2", JsepSupportedFloatTypes()), \ + KERNEL_CLASS); + +REGISTER_BATCHNORM_KERNEL(BatchNormalization, kMSInternalNHWCDomain, BatchNorm); +REGISTER_BATCHNORM_KERNEL(BatchNormalization, kOnnxDomain, BatchNorm); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/batch_norm.h b/onnxruntime/core/providers/js/operators/batch_norm.h new file mode 100644 index 0000000000000..bb987a8aeab44 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/batch_norm.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +template +class BatchNorm final : public JsKernel { + public: + explicit BatchNorm(const OpKernelInfo& info) : JsKernel(info) { + float epsilon = info.GetAttrOrDefault("epsilon", 1e-5); + float momentum = info.GetAttrOrDefault("momentum", 0.9); + int64_t spatial = info.GetAttrOrDefault("spatial", 1); + + const auto& node = info.node(); + int opset = node.SinceVersion(); + int64_t training_mode = opset <= 9 ? info.GetOutputCount() > 1 : info.GetAttrOrDefault("training_mode", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(BatchNormalization, ({ + "epsilon" : $1, + "momentum" : $2, + "spatial" : !!$4, + "trainingMode" : !!$3, + "format" : $5 ? 
"NHWC" : "NCHW", + }), + static_cast(epsilon), static_cast(momentum), + static_cast(training_mode), static_cast(spatial), + static_cast(is_channels_last)); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/concat.cc b/onnxruntime/core/providers/js/operators/concat.cc index 3a6a7e1cafd7a..17c6b0466c3a5 100644 --- a/onnxruntime/core/providers/js/operators/concat.cc +++ b/onnxruntime/core/providers/js/operators/concat.cc @@ -12,7 +12,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 1, 3, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -22,7 +23,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 4, 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -32,7 +34,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 11, 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -42,7 +45,8 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index b030efa238209..454f3dd5eb3cc 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -146,26 +146,15 @@ std::vector> GetCapability::Execute() { // If subgraph has less then three, graph is considered trivial if (this_cluster.size() < 3) { continue; - } else { - // If subgraph only has Identity node, EyeLike or Dropout, OpenVINO EP doesn't support it. 
- if (this_cluster.size() == 1) { - const auto& node = graph_viewer_.GetNode(this_cluster[0]); - if (IsOpSupportedOnlyInModel(node->OpType())) - continue; - // If reshape is not an intermediate node, shape needs to be an initializer - if (data_ops_->SpecialConditionForClusterSizeOne(ng_required_initializers, node)) - continue; - } } - std::vector cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs; + std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; GetInputsOutputsOfCluster(graph_viewer_, this_cluster, ng_required_initializers, cluster_graph_inputs, cluster_inputs, - const_inputs, cluster_outputs); bool omit_subgraph = false; diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index 74369d39b9a24..ee0bfddb7dc83 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -180,12 +180,12 @@ void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::unordered_set& ng_required_initializers, /*out*/ std::vector& cluster_graph_inputs, /*out*/ std::vector& cluster_inputs, - /*out*/ std::vector& constant_inputs, /*out*/ std::vector& cluster_outputs) { std::unordered_set input_args; std::vector ordered_input_args; std::unordered_set output_args; std::unordered_set external_output_args; + std::vector constant_inputs; for (const auto& node_idx : cluster) { const auto& node = graph_viewer.GetNode(node_idx); diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index c256cde97956e..b3edeef88dfec 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -45,7 +45,6 @@ void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::unordered_set& ng_required_initializers, /*out*/ std::vector& cluster_graph_inputs, /*out*/ std::vector& cluster_inputs, - /*out*/ std::vector& constant_inputs, /*out*/ std::vector& cluster_outputs); } // namespace openvino_ep diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index bd9986e661e21..234b957816662 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -60,10 +60,10 @@ Status CreateNodeArgs(const std::vector& names, return Status::OK(); } -Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const logging::Logger& logger) { +Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger) { using namespace onnxruntime; std::shared_ptr model; ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); @@ -74,10 +74,10 @@ Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_m qnn_model); } -Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model) { +Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const onnxruntime::PathString& ctx_onnx_model_path, + QnnBackendManager* 
qnn_backend_manager, + QnnModel& qnn_model) { const auto& node = graph_viewer.Nodes().begin(); NodeAttrHelper node_helper(*node); bool is_embed_mode = node_helper.Get(EMBED_MODE, true); @@ -89,11 +89,11 @@ Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewe } std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, ""); + std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); + std::filesystem::path context_binary_path = folder_path.append(external_qnn_context_binary_file_name); - std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() + - "/" + external_qnn_context_binary_file_name); size_t buffer_size{0}; - std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary); + std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary); ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); cache_file.seekg(0, cache_file.end); @@ -112,114 +112,122 @@ Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewe qnn_model); } -Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, - std::string& model_name, - std::string& model_description, - std::string& graph_partition_name, - std::string& cache_source, - const logging::Logger& logger) { - if (!is_metadata_ready_) { - using namespace onnxruntime; - std::shared_ptr model; - ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); - const auto& graph = GraphViewer(model->MainGraph()); - const auto& node = graph.Nodes().begin(); - NodeAttrHelper node_helper(*node); - model_name_ = graph.Name(); - model_description_ = graph.Description(); - graph_partition_name_ = node_helper.Get(PARTITION_NAME, ""); - cache_source_ = node_helper.Get(SOURCE, ""); - is_metadata_ready_ = true; +Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, + const onnxruntime::PathString& ctx_onnx_model_path, + bool is_qnn_ctx_model, + bool is_ctx_cache_file_exist, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger) { + Status status; + if (is_qnn_ctx_model) { + status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model); + } else if (is_ctx_cache_file_exist) { + status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger); + } + + if (!status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. 
", status.ErrorMessage()); } - model_name = model_name_; - model_description = model_description_; - graph_partition_name = graph_partition_name_; - cache_source = cache_source_; return Status::OK(); } -bool QnnCacheModelHandler::IsContextCacheFileExists(const std::string& customer_context_cache_path, - const std::string& model_description, - const onnxruntime::PathString& model_pathstring) { - // Avoid duplicate work - if (ctx_file_exists_) { - return ctx_file_exists_; - } - model_description_ = model_description; +Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path, + std::string& model_name, + std::string& model_description, + std::string& graph_partition_name, + std::string& cache_source, + const logging::Logger& logger) { + using namespace onnxruntime; + std::shared_ptr model; + ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger)); + const auto& graph = GraphViewer(model->MainGraph()); + const auto& node = graph.Nodes().begin(); + NodeAttrHelper node_helper(*node); + model_name = graph.Name(); + model_description = graph.Description(); + graph_partition_name = node_helper.Get(PARTITION_NAME, ""); + cache_source = node_helper.Get(SOURCE, ""); + + return Status::OK(); +} + +bool IsContextCacheFileExists(const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + onnxruntime::PathString& context_cache_path) { // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default - if (customer_context_cache_path.empty()) { - context_cache_path_ = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx"; - } else { - context_cache_path_ = customer_context_cache_path; + if (!customer_context_cache_path.empty()) { + context_cache_path = ToPathString(customer_context_cache_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_qnn_ctx.onnx"); } - ctx_file_exists_ = std::filesystem::is_regular_file(context_cache_path_) && std::filesystem::exists(context_cache_path_); - - return ctx_file_exists_; + return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); } -Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_name, - const std::string& graph_partition_name, - const logging::Logger& logger) { - ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!"); - +Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path, + const std::string& model_name, + const std::string& model_description, + const std::string& graph_partition_name, + const logging::Logger& logger) { std::string model_name_from_ctx_cache; std::string model_description_from_ctx_cache; std::string graph_partition_name_from_ctx_cache; std::string cache_source; - ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path_, - model_name_from_ctx_cache, - model_description_from_ctx_cache, - graph_partition_name_from_ctx_cache, - cache_source, - logger)); + auto status = GetMetadataFromEpContextModel(context_cache_path, + model_name_from_ctx_cache, + model_description_from_ctx_cache, + graph_partition_name_from_ctx_cache, + cache_source, + logger); + if (!status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel."); + } // The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT if (cache_source != kQnnExecutionProvider) { + 
LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort."; return Status::OK(); } - ORT_RETURN_IF(model_name != model_name_from_ctx_cache, - "Model file name from context cache metadata: " + model_name_from_ctx_cache + - " is different with target: " + model_name + - ". Please make sure the context binary file matches the model."); - - ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache, - "Model description from context cache metadata: " + model_description_from_ctx_cache + - " is different with target: " + model_description_ + - ". Please make sure the context binary file matches the model."); - - ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache && get_capability_round_2_, - "Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache + - " is different with target: " + graph_partition_name + - ". You may need to re-generate the context binary file."); + if (model_name != model_name_from_ctx_cache || + model_description != model_description_from_ctx_cache || + graph_partition_name != graph_partition_name_from_ctx_cache) { + std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ", + model_name, " ", model_description, " ", graph_partition_name, + " vs epcontext: ", + model_name_from_ctx_cache, " ", + model_description_from_ctx_cache, " ", + graph_partition_name_from_ctx_cache); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message); + } - get_capability_round_2_ = true; return Status::OK(); } -Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const logging::Logger& logger) { +Status GenerateCtxCacheOnnxModel(const std::string model_name, + const std::string model_description, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const onnxruntime::PathString& context_cache_path, + bool qnn_context_embed_mode, + const logging::Logger& logger) { std::unordered_map domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}}; - Model model(model_name_, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, logger); auto& graph = model.MainGraph(); - graph.SetDescription(model_description_); + graph.SetDescription(model_description); using namespace ONNX_NAMESPACE; int index = 0; // Still need more work to support multiple partition, it's out of EP's scope. // Already have code to make sure it's single partition before this method get invoked. 
for (const auto& fused_node_graph : fused_nodes_and_graphs) { - const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph); Node& fused_node = fused_node_graph.fused_node; - // graph_viewer.Name() is generated in GetCapability, e.g QNN_[hash_id]_[id] - // dump graph_viewer.Name() as metadata in context cache binary file, so that we can validate it in GetCapability auto qnn_model_kv = qnn_models.find(fused_node.Name()); ORT_RETURN_IF(qnn_model_kv == qnn_models.end(), fused_node.Name(), " not exist in QnnModel table."); @@ -229,7 +237,7 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetInputNames(), qnn_model->GetInputsInfo(), inputs, graph)); ORT_RETURN_IF_ERROR(CreateNodeArgs(qnn_model->GetOutputNames(), qnn_model->GetOutputsInfo(), outputs, graph)); - const std::string& graph_name = graph_viewer.Name(); + const std::string& graph_name = fused_node.Name(); auto& ep_node = graph.AddNode(graph_name, EPCONTEXT_OP, "Onnx Qnn context binary cache for graph partition: " + graph_name, @@ -240,13 +248,13 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, // Only dump the context buffer once since all QNN graph are in one single context if (0 == index) { - if (qnn_context_embed_mode_) { + if (qnn_context_embed_mode) { std::string cache_payload(buffer, buffer + buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload); } else { - std::string context_cache_path(context_cache_path_ + "_" + graph_name + ".bin"); - std::string context_cache_name(std::filesystem::path(context_cache_path).filename().string()); - std::ofstream of_stream(context_cache_path.c_str(), std::ofstream::binary); + onnxruntime::PathString context_bin_path = context_cache_path + ToPathString("_" + graph_name + ".bin"); + std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string()); + std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary); if (!of_stream) { LOGS(logger, ERROR) << "Failed to open create context file."; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file."); @@ -257,7 +265,7 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, } else { ep_node.AddAttribute(MAIN_CONTEXT, static_cast(0)); } - int64_t embed_mode = qnn_context_embed_mode_ ? static_cast(1) : static_cast(0); + int64_t embed_mode = qnn_context_embed_mode ? 
static_cast(1) : static_cast(0); ep_node.AddAttribute(EMBED_MODE, embed_mode); ep_node.AddAttribute(EP_SDK_VER, sdk_build_version); ep_node.AddAttribute(PARTITION_NAME, graph_name); @@ -265,7 +273,7 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer, ++index; } ORT_RETURN_IF_ERROR(graph.Resolve()); - ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path_)); + ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path)); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index e9ca87a679ecc..0011d0f43f5bc 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -38,77 +38,50 @@ Status CreateNodeArgs(const std::vector& names, std::vector& node_args, onnxruntime::Graph& graph); -class QnnCacheModelHandler { - public: - QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) { - } - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler); - - Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - bool is_qnn_ctx_model, - bool is_ctx_cache_file_exist, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const logging::Logger& logger) { - if (is_qnn_ctx_model) { - return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model); - } else if (is_ctx_cache_file_exist) { - return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger); - } - return Status::OK(); - } - - bool IsContextCacheFileExists(const std::string& customer_context_cache_path, - const std::string& model_description, - const onnxruntime::PathString& model_pathstring); - - bool GetIsContextCacheFileExists() const { - return ctx_file_exists_; - } - - Status ValidateWithContextFile(const std::string& model_name, - const std::string& graph_name, - const logging::Logger& logger); - - Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path, - std::string& model_name, - std::string& model_description, - std::string& graph_partition_name, - std::string& cache_source, - const logging::Logger& logger); - - Status GenerateCtxCacheOnnxModel(unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const logging::Logger& logger); - - private: - Status GetEpContextFromModel(const std::string& ctx_onnx_model_path, +bool IsContextCacheFileExists(const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + onnxruntime::PathString& context_cache_path); + +Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model, + const logging::Logger& logger); + +Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, + const onnxruntime::PathString& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + QnnModel& qnn_model); + +Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, + const onnxruntime::PathString& ctx_onnx_model_path, + bool is_qnn_ctx_model, + bool is_ctx_cache_file_exist, QnnBackendManager* qnn_backend_manager, QnnModel& qnn_model, const logging::Logger& logger); - Status GetEpContextFromGraph(const 
onnxruntime::GraphViewer& graph_viewer, - const std::string& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model); - - private: - bool is_metadata_ready_ = false; - // model_name_ to cache_source_ -- metadata get from generated Qnn context binary Onnx model - std::string model_name_ = ""; - std::string model_description_ = ""; - std::string graph_partition_name_ = ""; - std::string cache_source_ = ""; - - std::string context_cache_path_ = ""; - bool ctx_file_exists_ = false; - bool get_capability_round_2_ = false; - bool qnn_context_embed_mode_ = true; -}; // QnnCacheModelHandler +Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path, + const std::string& model_name, + const std::string& model_description, + const std::string& graph_partition_name, + const logging::Logger& logger); +Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path, + std::string& model_name, + std::string& model_description, + std::string& graph_partition_name, + std::string& cache_source, + const logging::Logger& logger); + +Status GenerateCtxCacheOnnxModel(const std::string model_name, + const std::string model_description, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const onnxruntime::PathString& context_cache_path, + bool qnn_context_embed_mode, + const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index d5c3e4619f263..f1a5d41a8a6ff 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -63,6 +63,8 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("SpaceToDepth", *this); CreateSimpleOpBuilder("GridSample", *this); + + CreateSimpleOpBuilder("LpNormalization", *this); } { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index c979e599f96c4..4eb599eb50175 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -161,6 +161,7 @@ class BaseOpBuilder : public IOpBuilder { {"Tanh", QNN_OP_TANH}, {"Transpose", QNN_OP_TRANSPOSE}, {"GridSample", QNN_OP_GRID_SAMPLE}, + {"LpNormalization", QNN_OP_L2_NORM}, {"DequantizeLinear", QNN_OP_DEQUANTIZE}, {"QuantizeLinear", QNN_OP_QUANTIZE}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index fdc5317419c5b..dd678ab5467ed 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -335,6 +335,19 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); } + if (op_type == "LpNormalization") { + int32_t default_axis = -1; + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis)); + QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_L2_NORM_PARAM_AXIS, axis_qnn_scalar); + 
param_tensor_names.push_back(axis_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); + + NodeAttrHelper node_helper(node_unit); + int64_t norm_p_order = node_helper.Get("p", static_cast(2)); + ORT_RETURN_IF(norm_p_order != 2, "QNN EP only supports LpNormalization with 'p' attribute equal to 2."); + } + if (op_type == "MatMul") { Qnn_Scalar_t scalar_param = QNN_SCALAR_INIT; scalar_param.dataType = QNN_DATATYPE_BOOL_8; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index aac82c89d6f49..4edccea661642 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -22,7 +22,6 @@ namespace onnxruntime { namespace qnn { class QnnModel; -class QnnCacheModelHandler; class QnnBackendManager { public: diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 8acd0d68b71d0..c7b309ae471c9 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -16,20 +16,12 @@ #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" +#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" namespace onnxruntime { constexpr const char* QNN = "QNN"; -std::string GetFileNameFromModelPath(onnxruntime::Path model_path) { - auto model_path_components = model_path.GetComponents(); - // There's no model path if model loaded from buffer stead of file - if (model_path_components.empty()) { - return ""; - } - return PathToUTF8String(model_path_components.back()); -} - void QNNExecutionProvider::ParseProfilingLevel(std::string profiling_level_string) { std::transform(profiling_level_string.begin(), profiling_level_string.end(), @@ -134,16 +126,15 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio static const std::string CONTEXT_CACHE_PATH = "qnn_context_cache_path"; auto context_cache_path_pos = provider_options_map.find(CONTEXT_CACHE_PATH); if (context_cache_path_pos != provider_options_map.end()) { - context_cache_path_ = context_cache_path_pos->second; - LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_; + context_cache_path_cfg_ = context_cache_path_pos->second; + LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; } - bool qnn_context_embed_mode = true; static const std::string CONTEXT_CACHE_EMBED_MODE = "qnn_context_embed_mode"; auto context_cache_embed_mode_pos = provider_options_map.find(CONTEXT_CACHE_EMBED_MODE); if (context_cache_embed_mode_pos != provider_options_map.end()) { - qnn_context_embed_mode = context_cache_embed_mode_pos->second == "1"; - LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode; + qnn_context_embed_mode_ = context_cache_embed_mode_pos->second == "1"; + LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_; } static const std::string BACKEND_PATH = "backend_path"; @@ -206,7 +197,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio htp_performance_mode_, context_priority_, std::move(qnn_saver_path)); - qnn_cache_model_handler_ = std::make_unique(qnn_context_embed_mode); } bool 
QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -343,9 +333,10 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // This is for case: QDQ model + Onnx Qnn context cache model if (context_cache_enabled_ && !is_qnn_ctx_model) { - load_from_cached_context = qnn_cache_model_handler_->IsContextCacheFileExists(context_cache_path_, - graph_viewer.Description(), - graph_viewer.ModelPath().ToPathString()); + onnxruntime::PathString context_cache_path; + load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_, + graph_viewer.ModelPath().ToPathString(), + context_cache_path); } // Load from cached context will load the QnnSystem lib and skip the Qnn context creation @@ -444,17 +435,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } const size_t num_of_partitions = result.size(); - - if (!is_qnn_ctx_model && load_from_cached_context && 1 == num_of_partitions) { - rt = qnn_cache_model_handler_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()), - result[0]->sub_graph->GetMetaDef()->name, - logger); - if (Status::OK() != rt) { - LOGS(logger, ERROR) << "QNN failed to validate context cache metadata: " << rt.ErrorMessage(); - return result; - } - } - const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions, ", number of nodes in the graph: ", num_nodes_in_graph, ", number of nodes supported by QNN: ", num_of_supported_nodes); @@ -547,25 +527,38 @@ Status QNNExecutionProvider::Compile(const std::vector& fused bool is_qnn_ctx_model = false; ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model)); - bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists(); + onnxruntime::PathString context_cache_path; + bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, + graph_viewer.ModelPath().ToPathString(), + context_cache_path); + const std::string& model_name = graph_viewer.GetGraph().Name(); + const std::string& model_description = graph_viewer.GetGraph().Description(); + const std::string& graph_meta_id = fused_node.Name(); + if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) { + ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path, + model_name, + model_description, + graph_meta_id, + logger)); + } + if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); // Load and execute from cached context if exist - ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer, - context_cache_path_, - is_qnn_ctx_model, - is_ctx_file_exist, - qnn_backend_manager_.get(), - *(qnn_model.get()), - logger)); + ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer, + context_cache_path, + is_qnn_ctx_model, + is_ctx_file_exist, + qnn_backend_manager_.get(), + *(qnn_model.get()), + logger)); ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] // the name here should be same with context->node_name in compute_info - LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name(); - 
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model)); + qnn_models_.emplace(graph_meta_id, std::move(qnn_model)); ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); return Status::OK(); @@ -576,12 +569,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); - ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GenerateCtxCacheOnnxModel(context_buffer.get(), - buffer_size, - qnn_backend_manager_->GetSdkVersion(), - fused_nodes_and_graphs, - qnn_models_, - logger)); + ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name, + model_description, + context_buffer.get(), + buffer_size, + qnn_backend_manager_->GetSdkVersion(), + fused_nodes_and_graphs, + qnn_models_, + context_cache_path, + qnn_context_embed_mode_, + logger)); } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index cf0bff8890d0c..8c99a916a6f69 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -8,7 +8,6 @@ #include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" -#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/qnn_graph_configs_helper.h" namespace onnxruntime { @@ -71,10 +70,10 @@ class QNNExecutionProvider : public IExecutionProvider { std::unordered_map> qnn_models_; uint32_t rpc_control_latency_ = 0; bool context_cache_enabled_ = false; - std::string context_cache_path_ = ""; + std::string context_cache_path_cfg_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. - std::unique_ptr qnn_cache_model_handler_; qnn::ContextPriority context_priority_ = qnn::ContextPriority::NORMAL; + bool qnn_context_embed_mode_ = true; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_triton.cuh b/onnxruntime/core/providers/rocm/math/softmax_triton.cuh index 737e396855e35..cc0e0d70056cc 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_triton.cuh +++ b/onnxruntime/core/providers/rocm/math/softmax_triton.cuh @@ -60,7 +60,7 @@ auto GetSoftmaxTritonOps() { } args = {(void*)params->output, (const void*)params->input, params->input_stride, params->output_stride, params->softmax_elements}; // grid dim is (batch_count, 1, 1) - return LaunchTritonKernel(params->stream, i, params->batch_count, 1, 1, &args, sizeof(args)); + return LaunchTritonKernel(params->StreamHandle(), i, params->batch_count, 1, 1, &args, sizeof(args)); }; ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); } diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index b9c0cdcc1c341..776dabd757af4 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -26,6 +26,10 @@ using onnxruntime::contrib::rocm::blas::GemmFastGeluParams; #ifdef USE_HIPBLASLT +// For large K and small M/N, K dim will be split to multiple workgroups and buffers, +// which will require additional workspace. Here we set the max workspace size to 32MB. 
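+// 32 MB here is 32 * 1024 * 1024 bytes; in the tunable op below, a solution that asks
+// for more workspace than this is rejected as an unsupported argument, and otherwise a
+// full 32 MB scratch buffer is allocated from the tuning context.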
+constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024; + enum ActivationType { NONE = 0, RELU = 1, @@ -225,6 +229,9 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp IAllocatorUniquePtr workspace_buffer; if (workspace_size > 0) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(workspace_size > kHipBlasLtMaxWorkSpaceSizeInBytes, + "Workspace size exceeds limit (32M): ", workspace_size); + workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes; workspace_buffer = params->tuning_ctx->GetScratchBuffer(workspace_size, params->stream); } diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 6b1207d3d16f0..39ea4dd8412bb 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -166,6 +166,12 @@ std::vector NodeAttrHelper::Get(const std::string& key, const std::vector return std::vector{source.cbegin(), source.cend()}; } +std::optional NodeAttrHelper::GetInt(const std::string& key) const { + if (!HasAttr(key)) + return std::nullopt; + return node_attributes_.at(key).i(); +} + bool NodeAttrHelper::HasAttr(const std::string& key) const { return Contains(node_attributes_, key); } diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index db07938c1897e..1e93f040711df 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -6,6 +6,7 @@ #include #include #include +#include #include "core/graph/basic_types.h" @@ -57,6 +58,8 @@ class NodeAttrHelper { uint32_t Get(const std::string& key, uint32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; + std::optional GetInt(const std::string& key) const; + bool HasAttr(const std::string& key) const; private: diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 38266f566e6e1..d34cb7e362446 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -85,7 +85,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const auto* node(graph_viewer.GetNode(node_idx)); bool supported = false; // Firstly check if platform supports the WebNN op. - if (CheckSingleOp(node->OpType(), wnn_builder_)) { + if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; supported = IsNodeSupported(*node, graph_viewer, device_type, logger); } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 8ae16f0dd21fc..28b54b9c9cf8d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -30,6 +30,11 @@ enum class WebnnDeviceType { GPU, }; +typedef struct { + std::string opName; + bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP). 
+} WebnnOpInfo; + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); template @@ -128,90 +133,107 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const emscripten::val& wnn_builder_, const WebnnDeviceType device_type, const logging::Logger& logger); -static const InlinedHashMap op_map = { - {"Abs", "abs"}, - {"Add", "add"}, - {"ArgMax", "argMax"}, - {"ArgMin", "argMin"}, - {"AveragePool", "averagePool2d"}, - {"BatchNormalization", "meanVarianceNormalization"}, - {"Cast", "cast"}, - {"Ceil", "ceil"}, - {"Clip", "clamp"}, - {"Concat", "concat"}, - {"Conv", "conv2d"}, - {"ConvTranspose", "convTranspose2d"}, - {"Cos", "cos"}, - {"Div", "div"}, - {"Elu", "elu"}, - {"Equal", "equal"}, - {"Erf", "erf"}, - {"Exp", "exp"}, - {"Expand", "expand"}, - {"Flatten", "flattenTo2d"}, - {"Floor", "floor"}, - {"Gather", "gather"}, - {"Gemm", "gemm"}, - {"GlobalAveragePool", "averagePool2d"}, - {"GlobalMaxPool", "maxPool2d"}, - {"GlobalLpPool", "l2Pool2d"}, - {"Greater", "greater"}, - {"GreaterOrEqual", "greaterOrEqual"}, - {"GroupNormalization", "meanVarianceNormalization"}, - {"HardSigmoid", "hardSigmoid"}, - {"HardSwish", "hardSwish"}, - {"Identity", "identity"}, - {"InstanceNormalization", "meanVarianceNormalization"}, - {"LayerNormalization", "meanVarianceNormalization"}, - {"LeakyRelu", "leakyRelu"}, - {"Less", "lesser"}, - {"LessOrEqual", "lesserOrEqual"}, - {"Log", "log"}, - {"LpPool", "l2Pool2d"}, - {"MatMul", "matmul"}, - {"Max", "max"}, - {"MaxPool", "maxPool2d"}, - {"Min", "min"}, - {"Mul", "mul"}, - {"Neg", "neg"}, - {"Not", "logicalNot"}, - {"Pad", "pad"}, - {"Pow", "pow"}, - {"PRelu", "prelu"}, - {"Reciprocal", "reciprocal"}, - {"ReduceL1", "reduceL1"}, - {"ReduceL2", "reduceL2"}, - {"ReduceLogSum", "reduceLogSum"}, - {"ReduceLogSumExp", "reduceLogSumExp"}, - {"ReduceMax", "reduceMax"}, - {"ReduceMean", "reduceMean"}, - {"ReduceMin", "reduceMin"}, - {"ReduceProd", "reduceProduct"}, - {"ReduceSum", "reduceSum"}, - {"ReduceSumSquare", "reduceSumSquare"}, - {"Relu", "relu"}, - {"Reshape", "reshape"}, - {"Resize", "resample2d"}, - {"Shape", "slice"}, - {"Sigmoid", "sigmoid"}, - {"Softplus", "softplus"}, - {"Softsign", "softsign"}, - {"Sin", "sin"}, - {"Slice", "slice"}, - {"Softmax", "softmax"}, - {"Split", "split"}, - {"Sqrt", "sqrt"}, - {"Squeeze", "squeeze"}, - {"Sub", "sub"}, - {"Tan", "tan"}, - {"Tanh", "tanh"}, - {"Transpose", "transpose"}, - {"Unsqueeze", "unsqueeze"}, - {"Where", "elementwiseIf"}, +static const InlinedHashMap op_map = { + {"Abs", {"abs", true}}, + {"Add", {"add", true}}, + {"ArgMax", {"argMax", false}}, + {"ArgMin", {"argMin", false}}, + {"AveragePool", {"averagePool2d", true}}, + {"BatchNormalization", {"meanVarianceNormalization", false}}, + {"Cast", {"cast", false}}, + {"Ceil", {"ceil", true}}, + {"Clip", {"clamp", true}}, + {"Concat", {"concat", true}}, + {"Conv", {"conv2d", true}}, + {"ConvTranspose", {"convTranspose2d", true}}, + {"Cos", {"cos", false}}, + {"Div", {"div", true}}, + {"Elu", {"elu", true}}, + {"Equal", {"equal", false}}, + {"Erf", {"erf", false}}, + {"Exp", {"exp", false}}, + {"Expand", {"expand", false}}, + {"Flatten", {"flattenTo2d", false}}, + {"Floor", {"floor", true}}, + {"Gather", {"gather", false}}, + {"Gemm", {"gemm", true}}, + {"GlobalAveragePool", {"averagePool2d", true}}, + {"GlobalMaxPool", {"maxPool2d", true}}, + {"GlobalLpPool", {"l2Pool2d", false}}, + {"Greater", {"greater", false}}, + {"GreaterOrEqual", {"greaterOrEqual", false}}, + {"GroupNormalization", {"meanVarianceNormalization", 
false}}, + {"HardSigmoid", {"hardSigmoid", false}}, + {"HardSwish", {"hardSwish", true}}, + {"Identity", {"identity", false}}, + {"InstanceNormalization", {"meanVarianceNormalization", false}}, + {"LayerNormalization", {"meanVarianceNormalization", false}}, + {"LeakyRelu", {"leakyRelu", true}}, + {"Less", {"lesser", false}}, + {"LessOrEqual", {"lesserOrEqual", false}}, + {"Log", {"log", false}}, + {"LpPool", {"l2Pool2d", false}}, + {"MatMul", {"matmul", false}}, + {"Max", {"max", true}}, + {"MaxPool", {"maxPool2d", true}}, + {"Min", {"min", true}}, + {"Mul", {"mul", true}}, + {"Neg", {"neg", true}}, + {"Not", {"logicalNot", false}}, + {"Pad", {"pad", true}}, + {"Pow", {"pow", true}}, + {"PRelu", {"prelu", true}}, + {"Reciprocal", {"reciprocal", false}}, + {"ReduceL1", {"reduceL1", false}}, + {"ReduceL2", {"reduceL2", false}}, + {"ReduceLogSum", {"reduceLogSum", false}}, + {"ReduceLogSumExp", {"reduceLogSumExp", false}}, + {"ReduceMax", {"reduceMax", false}}, + {"ReduceMean", {"reduceMean", true}}, + {"ReduceMin", {"reduceMin", false}}, + {"ReduceProd", {"reduceProduct", false}}, + {"ReduceSum", {"reduceSum", true}}, + {"ReduceSumSquare", {"reduceSumSquare", false}}, + {"Relu", {"relu", true}}, + {"Reshape", {"reshape", true}}, + {"Resize", {"resample2d", true}}, + {"Shape", {"slice", true}}, + {"Sigmoid", {"sigmoid", true}}, + {"Softplus", {"softplus", false}}, + {"Softsign", {"softsign", false}}, + {"Sin", {"sin", false}}, + {"Slice", {"slice", true}}, + {"Softmax", {"softmax", true}}, + {"Split", {"split", true}}, + {"Sqrt", {"sqrt", false}}, + {"Squeeze", {"squeeze", false}}, + {"Sub", {"sub", true}}, + {"Tan", {"tan", false}}, + {"Tanh", {"tanh", true}}, + {"Transpose", {"transpose", true}}, + {"Unsqueeze", {"unsqueeze", false}}, + {"Where", {"elementwiseIf", false}}, }; -inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_) { - return op_map.find(op_type) != op_map.end() && wnn_builder_[op_map.find(op_type)->second].as(); +inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_, + const WebnnDeviceType device_type) { + // Returns false if the op_type is not listed in the op_map. + if (op_map.find(op_type) == op_map.end()) { + return false; + } + // Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser. + if (!wnn_builder_[op_map.find(op_type)->second.opName].as()) { + return false; + } + // The current WebNN CPU (XNNPack) backend supports a limited op list, and we'd rather + // fall back early to the ORT CPU EP rather than fail in the WebNN "cpu" deviceType. + // This is a workaround because the op may be included in MLGraphBuilder for DirectML + // backend but without XNNPack implementation in Chromium. 
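+  // For example, "Cos" is listed as {"cos", false} in op_map above, so even when
+  // MLGraphBuilder exposes cos(), the node is rejected here and assigned to the ORT
+  // CPU EP instead of failing inside the WebNN CPU (XNNPack) backend.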
+ if (!op_map.find(op_type)->second.isCpuSupported) { + return false; + } + + return true; } constexpr std::array supported_cpu_data_types = { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f02d180ab104f..75be72658f98f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -74,6 +74,7 @@ #ifdef ENABLE_TRAINING #include "core/framework/partial_graph_execution_state.h" #include "core/framework/stream_execution_context.h" +#include "orttraining/core/optimizer/memory_optimizer.h" #endif using namespace ONNX_NAMESPACE; @@ -1149,6 +1150,20 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph)); } +#ifdef ENABLE_TRAINING + // Enable memory optimizations (mainly insert recomputation nodes with priority). + // Only applicable for training scenarios. + { + const std::string memory_optimizer_config = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerEnabler, ""); + const std::string probe_level = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsMemoryOptimizerProbeLevel, "0"); + + MemoryOptimizer mem_transformer{memory_optimizer_config, probe_level}; + ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(mem_transformer, *session_logger_, graph)); + } +#endif + return Status::OK(); } #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h new file mode 100644 index 0000000000000..a0405e32034ae --- /dev/null +++ b/onnxruntime/core/util/matrix_layout.h @@ -0,0 +1,475 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * matrix_layout.h + * + * Abstract: + * Utils for simplifying positioning and striding in tensors. Inspired + * by CUTLASS, striving for 0 runtime cost while promote safety. + * + * Only supports 2D tensors (matrix) for now. + */ + +#pragma once + +#include +#include "core/common/gsl.h" + +// TODO!! Already have this in cuda, what about cpu code though? +#if defined(_MSC_VER) +#define ORT_FORCEINLINE __forceinline +#else +#define ORT_FORCEINLINE __attribute__((always_inline)) inline +#endif + +namespace onnxruntime { + +// +// Clang-format doesn't handle force inline decorator well, it insists on +// adding extra indentation to the next line, making it very confusing +// to read. So we turn it off for this file. 
+// clang-format off +// + +/** + * @brief A tuple of integers to represent tensor coordinates + */ +template < + int Rank_, ///< Logical rank of coordinate + typename Index_ = int, ///< Index type used for each dimension + typename LongIndex_ = int64_t ///< Long index type used for linear offsets + > +struct Position { + public: + /// Number of elements in Position + static int const kRank = Rank_; + + /// Index type used to store elements + using Index = Index_; + + /// Type used to represent linear offsets + using LongIndex = LongIndex_; + + private: + Index idx[kRank]; + + public: + ORT_FORCEINLINE explicit Position(Index value = Index(0)) { + for (int i = 0; i < kRank; ++i) { + idx[i] = value; + } + } + + /// Constructs from an array of integers + ORT_FORCEINLINE + Position(Index const (&_idx)[kRank]) { + for (int i = 0; i < kRank; ++i) { + idx[i] = _idx[i]; + } + } + + template + ORT_FORCEINLINE + Position(Position other) { + for (int i = 0; i < kRank; ++i) { + idx[i] = other[i]; + } + } + + ORT_FORCEINLINE + Position operator+(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] + b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position operator-(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] - b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position operator*(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] * b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position operator/(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] / b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position& operator+=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] += b.idx[i]; + } + return *this; + } + + ORT_FORCEINLINE + Position& operator-=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] -= b.idx[i]; + } + return *this; + } + + ORT_FORCEINLINE + Position& operator*=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] *= b.idx[i]; + } + return *this; + } + + ORT_FORCEINLINE + Position& operator/=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] /= b.idx[i]; + } + return *this; + } + + ORT_FORCEINLINE Index& operator[](int dim) { return idx[dim]; } + + ORT_FORCEINLINE Index const& operator[](int dim) const { return idx[dim]; } + + ORT_FORCEINLINE bool operator==(Position const& b) const { + bool equal = true; + for (int i = 0; equal && i < kRank; ++i) { + equal = (idx[i] == b.idx[i]); + } + return equal; + } + + ORT_FORCEINLINE bool operator!=(Position const& b) const { return !(*this == b); } + + ORT_FORCEINLINE + Position& clamp(Position const& max, Position const& min = Position()) { + for (int i = 0; i < kRank; ++i) { + idx[i] = std::max(std::min(idx[i], max.idx[i]), min.idx[i]); + } + return *this; + } + + ORT_FORCEINLINE + Index sum() const { + Index sum_(idx[0]); + for (int i = 1; i < kRank; ++i) { + sum_ += idx[i]; + } + return sum_; + } + + ORT_FORCEINLINE + LongIndex product() const { + LongIndex product_(idx[0]); + for (int i = 1; i < kRank; ++i) { + product_ *= idx[i]; + } + return product_; + } +}; + +template +Position<2, T, L> make_Position(T _0, T _1) { + T values[2] = {_0, _1}; + return Position<2, T, L>(values); +} + +template +Position<3, T, L> make_Position(T _0, T _1, T _2) { + T values[3] = {_0, _1, _2}; + return Position<2, T, L>(values); +} + +/// Describes the size of a matrix tile +template < + int Row_, ///< 
rows of a matrix + int Column_ ///< columns of a matrix + > +struct MatrixShape { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix + + ORT_FORCEINLINE static Position<2> toCoord() { + return make_Position(kRow, kColumn); + } +}; + +/** + * @brief Defines a mapping from logical coordinate to linear memory + * offsets in a row major layout matrix + */ +class RowMajorLayout { + public: + /// Index type used for coordinates + using Index = int; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using MatCoord = Position<2, Index, LongIndex>; + + private: + Index stride_; + + public: + ORT_FORCEINLINE + RowMajorLayout(Index ldm = 0) : stride_(ldm) {} + + ORT_FORCEINLINE static RowMajorLayout packed(MatCoord const& extent) { + return RowMajorLayout(extent[1]); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ORT_FORCEINLINE + LongIndex operator()(MatCoord const& coord) const { + return LongIndex(coord[0]) * stride_ + coord[1]; + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + ORT_FORCEINLINE + MatCoord inverse(LongIndex offset) const { + return make_Position(Index(offset / stride_), Index(offset % stride_)); + } + + ORT_FORCEINLINE + Index stride() const { + return stride_; + } +}; + +class ColumnMajorLayout { + public: + /// Index type used for coordinates + using Index = int; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using MatCoord = Position<2, Index, LongIndex>; + + private: + Index stride_; + + public: + ORT_FORCEINLINE + ColumnMajorLayout(Index ldm = 0) : stride_(ldm) {} + + ORT_FORCEINLINE static ColumnMajorLayout packed(MatCoord const& extent) { + return ColumnMajorLayout(extent[0]); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ORT_FORCEINLINE + LongIndex operator()(MatCoord const& coord) const { + return LongIndex(coord[1]) * LongIndex(stride_) + coord[0]; + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + ORT_FORCEINLINE + MatCoord inverse(LongIndex offset) const { + return make_Position(Index(offset % stride_), Index(offset / stride_)); + } + + ORT_FORCEINLINE + Index stride() const { + return stride_; + } +}; + +/** + * @brief A reference to a tensor, with a layout object to map logical + * coordinates to linear offsets. 
+ */ +template < + /// Data type of element stored within tensor, must be numerical types + typename Element_, + /// Defines a mapping from logical coordinate to linear memory offsets + typename Layout_, + /// If true, extra bounds checking is performed on all accesses + bool ExtraBoundsCheck_ = false> +class MatrixRef { + public: + /// Data type of individual access + using Element = Element_; + + using Reference = Element&; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using MatCoord = typename Layout::MatCoord; + + /// MatrixRef to constant data + using ConstMatrixRef = MatrixRef< + typename std::remove_const::type const, + Layout, ExtraBoundsCheck_>; + + /// MatrixRef to non-constant data + using NonConstMatrixRef = MatrixRef< + typename std::remove_const::type, + Layout, ExtraBoundsCheck_>; + + static constexpr bool IsNonConstRef = std::is_same>::value; + + private: + /// Pointer to data + gsl::span data_; + + /// Shape of matrix + MatCoord shape_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + public: + ORT_FORCEINLINE + MatrixRef() : data_() {} + + ORT_FORCEINLINE + MatrixRef( + gsl::span const& data, ///< pointer to start of tensor + MatCoord const& shape ///< shape of tensor + ) : data_(data), shape_(shape), layout_(Layout::packed(shape)) { + Expects(data_.size() >= size_t(shape_.product())); + } + + ORT_FORCEINLINE + MatrixRef( + Element* ptr, ///< pointer to start of tensor + LongIndex size, ///< size of tensor in elements + MatCoord const& shape ///< shape of tensor + ) : data_(ptr, size), shape_(shape), layout_(Layout::packed(shape)) { + Expects(data_.size() >= shape_.product()); + } + + /// Converting constructor from MatrixRef to non-constant data. 
+ template + ORT_FORCEINLINE + MatrixRef( + NonConstMatrixRef const& ref, ///< MatrixRef to non-const data + /// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const + _Magic magic = (typename std::enable_if::type)0 + ) : data_(ref.data()), shape_(ref.shape()), layout_(Layout::packed(ref.shape())) {} + + ORT_FORCEINLINE + ConstMatrixRef const_ref() const { + return ConstMatrixRef(data_, shape_); + } + + ORT_FORCEINLINE + NonConstMatrixRef non_const_ref() { + return NonConstMatrixRef( + const_cast::type*>(data_.data()), + data_.size(), shape_); + } + + /// Returns true if the MatrixRef is non-null + ORT_FORCEINLINE + bool good() const { return !data_.empty(); } + + ORT_FORCEINLINE + gsl::span const& data() const { return data_; } + + ORT_FORCEINLINE + MatCoord const& shape() const { return shape_; } + + ORT_FORCEINLINE + Layout& layout() { return layout_; } + + ORT_FORCEINLINE + Layout layout() const { return layout_; } + + ORT_FORCEINLINE + Index stride() const { return layout_.stride(); } + + ORT_FORCEINLINE + Index& stride() { return layout_.stride(); } + + /// Computes the offset of an index from the origin of the tensor + ORT_FORCEINLINE + LongIndex offset(MatCoord const& coord) const { + if constexpr (ExtraBoundsCheck_) { + Expects(coord[0] >= 0 && coord[0] < shape_[0]); + Expects(coord[1] >= 0 && coord[1] < shape_[1]); + } + return layout_(coord); + } + + /// Returns a reference to the element at a given Coord + ORT_FORCEINLINE + Reference at(MatCoord const& coord) const { + return data_[offset(coord)]; + } + + ORT_FORCEINLINE + Reference at(int row, int col) const { + return data_[offset(make_Position(row, col))]; + } + + /// Returns a reference to the element at a given Coord + ORT_FORCEINLINE + Reference operator[](MatCoord const& coord) const { + return data_[offset(coord)]; + } +}; + +/// Constructs a MatrixRef, deducing types from arguments. +template < + typename Element, + typename Layout = RowMajorLayout, + bool ExtraBoundsCheck = false> +ORT_FORCEINLINE +MatrixRef +make_MatrixRef( + Element* ptr, + int64_t size, + typename Layout::MatCoord const& shape) { + return MatrixRef(ptr, size, shape); +} + +template < + typename Element, + typename Layout = RowMajorLayout, + bool ExtraBoundsCheck = false> +ORT_FORCEINLINE +MatrixRef +make_MatrixRef( + const gsl::span& span, + typename Layout::MatCoord const& shape) { + return MatrixRef(span, shape); +} + +// clang-format off + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/quantization/operators/softmax.py b/onnxruntime/python/tools/quantization/operators/softmax.py index cf943a1c1b8e1..32f45c26c424f 100644 --- a/onnxruntime/python/tools/quantization/operators/softmax.py +++ b/onnxruntime/python/tools/quantization/operators/softmax.py @@ -86,7 +86,6 @@ class QDQSoftmax(QDQOperatorBase): def quantize(self): super().quantize() output_name = self.node.output[0] - quant_overrides = self.quantizer.tensor_quant_overrides.get(output_name, {}) quant_type = self.quantizer.activation_qType diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 1ec1ca3ba0c83..54af8844d0c6c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -83,6 +83,9 @@ For example: If you do not provide prompt, the script will generate different image sizes for a list of prompts for demonstration. 
+#### Generate an image with SDXL LCM guided by a text prompt +```python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic"``` + ## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum If you are able to run the above demo with docker, you can use the docker and skip the following setup and fast forward to [Export ONNX pipeline](#export-onnx-pipeline). diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index 4636f139d4613..b3056cc47c647 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -22,7 +22,7 @@ import coloredlogs from cuda import cudart -from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_txt2img import Txt2ImgPipeline @@ -104,17 +104,25 @@ def run_inference(warmup=False): if not args.disable_cuda_graph: # inference once to get cuda graph - _image, _latency = run_inference(warmup=True) + _, _ = run_inference(warmup=True) print("[I] Warming up ..") for _ in range(args.num_warmup_runs): - _image, _latency = run_inference(warmup=True) + _, _ = run_inference(warmup=True) print("[I] Running StableDiffusion pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() - _image, _latency = run_inference(warmup=False) + images, perf_data = run_inference(warmup=False) if args.nvtx_profile: cudart.cudaProfilerStop() + metadata = get_metadata(args, False) + metadata.update(pipeline.metadata()) + if perf_data: + metadata.update(perf_data) + metadata["images"] = len(images) + print(metadata) + pipeline.save_images(images, prompt, negative_prompt, metadata) + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 4f9ecf6cbb152..7ff1794a68f8c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -22,7 +22,7 @@ import coloredlogs from cuda import cudart -from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_img2img_xl import Img2ImgXLPipeline @@ -54,7 +54,11 @@ def load_pipelines(args, batch_size): # No VAE decoder in base when it outputs latent instead of image. base_info = PipelineInfo( - args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size + args.version, + use_vae=args.disable_refiner, + min_image_size=min_image_size, + max_image_size=max_image_size, + use_lcm=args.lcm, ) # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. 
That is to @@ -118,7 +122,7 @@ def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False refiner.load_resources(image_height, image_width, batch_size) def run_base_and_refiner(warmup=False): - images, time_base = base.run( + images, base_perf = base.run( prompt, negative_prompt, image_height, @@ -130,24 +134,31 @@ def run_base_and_refiner(warmup=False): return_type="latent" if refiner else "image", ) if refiner is None: - return images, time_base + return images, base_perf # Use same seed in base and refiner. seed = base.get_current_seed() - images, time_refiner = refiner.run( + images, refiner_perf = refiner.run( prompt, negative_prompt, images, image_height, image_width, warmup=warmup, - denoising_steps=args.denoising_steps, - guidance=args.guidance, + denoising_steps=args.refiner_steps, + strength=args.strength, + guidance=args.refiner_guidance, seed=seed, ) - return images, time_base + time_refiner + perf_data = None + if base_perf and refiner_perf: + perf_data = {"latency": base_perf["latency"] + refiner_perf["latency"]} + perf_data.update({"base." + key: val for key, val in base_perf.items()}) + perf_data.update({"refiner." + key: val for key, val in refiner_perf.items()}) + + return images, perf_data if not args.disable_cuda_graph: # inference once to get cuda graph @@ -164,13 +175,24 @@ def run_base_and_refiner(warmup=False): print("[I] Running StableDiffusion XL pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() - _, latency = run_base_and_refiner(warmup=False) + images, perf_data = run_base_and_refiner(warmup=False) if args.nvtx_profile: cudart.cudaProfilerStop() - print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("e2e", latency)) - print("|------------|--------------|") + if refiner: + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", perf_data["latency"])) + print("|------------|--------------|") + + metadata = get_metadata(args, True) + metadata.update({"base." + key: val for key, val in base.metadata().items()}) + if refiner: + metadata.update({"refiner." 
+ key: val for key, val in refiner.metadata().items()}) + if perf_data: + metadata.update(perf_data) + metadata["images"] = len(images) + print(metadata) + (refiner or base).save_images(images, prompt, negative_prompt, metadata) def run_demo(args): @@ -189,6 +211,8 @@ def run_dynamic_shape_demo(args): """Run demo of generating images with different settings with ORT CUDA provider.""" args.engine = "ORT_CUDA" args.disable_cuda_graph = True + if args.lcm: + args.disable_refiner = True base, refiner = load_pipelines(args, 1) prompts = [ @@ -198,22 +222,31 @@ def run_dynamic_shape_demo(args): "cute grey cat with blue eyes, wearing a bowtie, acrylic painting", "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation", "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic", + "An astronaut riding a rainbow unicorn, cinematic, dramatic", + "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm", ] - # batch size, height, width, scheduler, steps, prompt, seed + # refiner, batch size, height, width, scheduler, steps, prompt, seed, guidance, refiner scheduler, refiner steps, refiner strength configs = [ - (1, 832, 1216, "UniPC", 8, prompts[0], None), - (1, 1024, 1024, "DDIM", 24, prompts[1], None), - (1, 1216, 832, "UniPC", 16, prompts[2], None), - (1, 1344, 768, "DDIM", 24, prompts[3], None), - (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712), - (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906), + (1, 832, 1216, "UniPC", 8, prompts[0], None, 5.0, "UniPC", 10, 0.3), + (1, 1024, 1024, "DDIM", 24, prompts[1], None, 5.0, "DDIM", 30, 0.3), + (1, 1216, 832, "UniPC", 16, prompts[2], None, 5.0, "UniPC", 10, 0.3), + (1, 1344, 768, "DDIM", 24, prompts[3], None, 5.0, "UniPC", 20, 0.3), + (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712, 5.0, "UniPC", 10, 0.3), + (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906, 5.0, "UniPC", 20, 0.3), ] + # In testing LCM, refiner is disabled so the settings of refiner is not used. + if args.lcm: + configs = [ + (1, 1024, 1024, "LCM", 8, prompts[6], None, 1.0, "UniPC", 20, 0.3), + (1, 1216, 832, "LCM", 6, prompts[7], 1337, 1.0, "UniPC", 20, 0.3), + ] + # Warm up each combination of (batch size, height, width) once before serving. args.prompt = ["warm up"] args.num_warmup_runs = 1 - for batch_size, height, width, _, _, _, _ in configs: + for batch_size, height, width, _, _, _, _, _, _, _, _ in configs: args.batch_size = batch_size args.height = height args.width = width @@ -223,7 +256,19 @@ def run_dynamic_shape_demo(args): # Run pipeline on a list of prompts. 
args.num_warmup_runs = 0 - for batch_size, height, width, scheduler, steps, example_prompt, seed in configs: + for ( + batch_size, + height, + width, + scheduler, + steps, + example_prompt, + seed, + guidance, + refiner_scheduler, + refiner_steps, + strength, + ) in configs: args.prompt = [example_prompt] args.batch_size = batch_size args.height = height @@ -231,12 +276,13 @@ def run_dynamic_shape_demo(args): args.scheduler = scheduler args.denoising_steps = steps args.seed = seed + args.guidance = guidance + args.refiner_scheduler = refiner_scheduler + args.refiner_steps = refiner_steps + args.strength = strength base.set_scheduler(scheduler) if refiner: - refiner.set_scheduler(scheduler) - print( - f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}" - ) + refiner.set_scheduler(refiner_scheduler) prompt, negative_prompt = repeat_prompt(args) run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 39ee273a3130d..70b4f34fdd988 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -21,6 +21,7 @@ # -------------------------------------------------------------------------- import argparse +from typing import Any, Dict import torch from diffusion_models import PipelineInfo @@ -68,8 +69,8 @@ def parse_arguments(is_xl: bool, description: str): "--scheduler", type=str, default="DDIM", - choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"], - help="Scheduler for diffusion process", + choices=["DDIM", "UniPC", "LCM"] if is_xl else ["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process" + " of base" if is_xl else "", ) parser.add_argument( @@ -105,6 +106,42 @@ def parse_arguments(is_xl: bool, description: str): help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", ) + if is_xl: + parser.add_argument( + "--lcm", + action="store_true", + help="Use fine-tuned latent consistency model to replace the UNet in base.", + ) + + parser.add_argument( + "--refiner-scheduler", + type=str, + default="DDIM", + choices=["DDIM", "UniPC"], + help="Scheduler for diffusion process of refiner.", + ) + + parser.add_argument( + "--refiner-guidance", + type=float, + default=5.0, + help="Guidance scale used in refiner.", + ) + + parser.add_argument( + "--refiner-steps", + type=int, + default=30, + help="Number of denoising steps in refiner. Note that actual refiner steps is refiner_steps * strength.", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.3, + help="A value between 0 and 1. 
The higher the value less the final image similar to the seed image.", + ) + # ONNX export parser.add_argument( "--onnx-opset", @@ -190,11 +227,52 @@ def parse_arguments(is_xl: bool, description: str): if args.onnx_opset is None: args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 + if is_xl: + if args.lcm: + if args.guidance > 1.0: + print("[I] Use --guidance=1.0 for base since LCM is used.") + args.guidance = 1.0 + if args.scheduler != "LCM": + print("[I] Use --scheduler=LCM for base since LCM is used.") + args.scheduler = "LCM" + if args.denoising_steps > 16: + print("[I] Use --denoising_steps=8 (no more than 16) for base since LCM is used.") + args.denoising_steps = 8 + assert args.strength > 0.0 and args.strength < 1.0 + print(args) return args +def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]: + metadata = { + "args.prompt": args.prompt, + "args.negative_prompt": args.negative_prompt, + "args.batch_size": args.batch_size, + "height": args.height, + "width": args.width, + "cuda_graph": not args.disable_cuda_graph, + "vae_slicing": args.enable_vae_slicing, + "engine": args.engine, + } + + if is_xl and not args.disable_refiner: + metadata["base.scheduler"] = args.scheduler + metadata["base.denoising_steps"] = args.denoising_steps + metadata["base.guidance"] = args.guidance + metadata["refiner.strength"] = args.strength + metadata["refiner.scheduler"] = args.refiner_scheduler + metadata["refiner.denoising_steps"] = args.refiner_steps + metadata["refiner.guidance"] = args.refiner_guidance + else: + metadata["scheduler"] = args.scheduler + metadata["denoising_steps"] = args.denoising_steps + metadata["guidance"] = args.guidance + + return metadata + + def repeat_prompt(args): if not isinstance(args.prompt, list): raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") @@ -223,7 +301,7 @@ def init_pipeline( # Initialize demo pipeline = pipeline_class( pipeline_info, - scheduler=args.scheduler, + scheduler=args.refiner_scheduler if pipeline_info.is_xl_refiner() else args.scheduler, output_dir=output_dir, hf_token=args.hf_token, verbose=False, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 514205d3b8945..8206bee753859 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -91,6 +91,7 @@ def __init__( min_image_size=256, max_image_size=1024, use_fp16_vae=True, + use_lcm=False, ): self.version = version self._is_inpaint = is_inpaint @@ -99,7 +100,9 @@ def __init__( self._min_image_size = min_image_size self._max_image_size = max_image_size self._use_fp16_vae = use_fp16_vae + self._use_lcm = use_lcm if is_refiner: + assert not use_lcm assert self.is_xl() def is_inpaint(self) -> bool: @@ -136,6 +139,9 @@ def custom_fp16_vae(self) -> Optional[str]: # For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None + def custom_unet(self) -> Optional[str]: + return "latent-consistency/lcm-sdxl" if self._use_lcm and self.is_xl_base() else None + @staticmethod def supported_versions(is_xl: bool): return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] @@ -730,8 +736,22 @@ def __init__( self.unet_dim = unet_dim self.time_dim = time_dim + 
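+        # Note: with classifier-free guidance the conditional and unconditional inputs are
+        # batched together, so the exported/runtime shapes below use 2 * batch_size ("2B");
+        # the LCM UNet runs with guidance <= 1, so no negative branch is needed and the
+        # multiplier drops to 1 ("B").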
self.custom_unet = pipeline_info.custom_unet() + self.do_classifier_free_guidance = not (self.custom_unet and "lcm" in self.custom_unet) + self.batch_multiplier = 2 if self.do_classifier_free_guidance else 1 + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + + if self.custom_unet: + model_dir = os.path.join(framework_model_dir, self.custom_unet, subfolder) + if not os.path.exists(model_dir): + unet = UNet2DConditionModel.from_pretrained(self.custom_unet, **options) + unet.save_pretrained(model_dir) + else: + unet = UNet2DConditionModel.from_pretrained(model_dir, **options) + return unet.to(self.device) + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) def get_input_names(self): @@ -741,12 +761,20 @@ def get_output_names(self): return ["latent"] def get_dynamic_axes(self): + if self.do_classifier_free_guidance: + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + } return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - "text_embeds": {0: "2B"}, - "time_ids": {0: "2B"}, + "sample": {0: "B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "B"}, + "latent": {0: "B", 2: "H", 3: "W"}, + "text_embeds": {0: "B"}, + "time_ids": {0: "B"}, } def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): @@ -763,49 +791,52 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, min_latent_width, max_latent_width, ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + m = self.batch_multiplier return { "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + (m * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (m * batch_size, self.unet_dim, latent_height, latent_width), + (m * max_batch, self.unet_dim, max_latent_height, max_latent_width), ], "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), + (m * min_batch, self.text_maxlen, self.embedding_dim), + (m * batch_size, self.text_maxlen, self.embedding_dim), + (m * max_batch, self.text_maxlen, self.embedding_dim), ], - "text_embeds": [(2 * min_batch, 1280), (2 * batch_size, 1280), (2 * max_batch, 1280)], + "text_embeds": [(m * min_batch, 1280), (m * batch_size, 1280), (m * max_batch, 1280)], "time_ids": [ - (2 * min_batch, self.time_dim), - (2 * batch_size, self.time_dim), - (2 * max_batch, self.time_dim), + (m * min_batch, self.time_dim), + (m * batch_size, self.time_dim), + (m * max_batch, self.time_dim), ], } def get_shape_dict(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + m = self.batch_multiplier return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "sample": (m * batch_size, self.unet_dim, latent_height, latent_width), "timestep": (1,), - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 
4, latent_height, latent_width), - "text_embeds": (2 * batch_size, 1280), - "time_ids": (2 * batch_size, self.time_dim), + "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (m * batch_size, 4, latent_height, latent_width), + "text_embeds": (m * batch_size, 1280), + "time_ids": (m * batch_size, self.time_dim), } def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) dtype = torch.float16 if self.fp16 else torch.float32 + m = self.batch_multiplier return ( torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device ), torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), { "added_cond_kwargs": { - "text_embeds": torch.randn(2 * batch_size, 1280, dtype=dtype, device=self.device), - "time_ids": torch.randn(2 * batch_size, self.time_dim, dtype=dtype, device=self.device), + "text_embeds": torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device), } }, ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py index 26c8450c57de9..6932c8056cf78 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -719,3 +719,228 @@ def configure(self): def __len__(self): return self.num_train_timesteps + + +# Modified from diffusers.schedulers.LCMScheduler +class LCMScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1.0, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + timestep_scaling: float = 10.0, + ): + self.device = device + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.final_alpha_cumprod = self.alphas_cumprod[0] + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + self.num_train_timesteps = num_train_timesteps + self.clip_sample = clip_sample + self.clip_sample_range = clip_sample_range + self.steps_offset = steps_offset + self.prediction_type = prediction_type + self.thresholding = thresholding + self.timestep_spacing = timestep_spacing + self.timestep_scaling = timestep_scaling + self.original_inference_steps = original_inference_steps + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.sample_max_value = sample_max_value + + 
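+        # Worked example with the defaults above: num_train_timesteps=1000 and
+        # original_inference_steps=50 give a stride of 20 between candidate timesteps, so
+        # set_timesteps(8) (strength left at 1.0) selects [999, 879, 759, 639, 519, 399, 279, 159].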
self._step_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + + @property + def step_index(self): + return self._step_index + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: int, + strength: int = 1.0, + ): + assert num_inference_steps <= self.num_train_timesteps + + self.num_inference_steps = num_inference_steps + original_steps = self.original_inference_steps + + assert original_steps <= self.num_train_timesteps + assert num_inference_steps <= original_steps + + # LCM Timesteps Setting + # Currently, only linear spacing is supported. + c = self.num_train_timesteps // original_steps + # LCM Training Steps Schedule + lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1 + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + # LCM Inference Steps Schedule + timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] + + self.timesteps = torch.from_numpy(timesteps.copy()).to(device=self.device, dtype=torch.long) + + self._step_index = None + + def get_scalings_for_boundary_condition_discrete(self, timestep): + self.sigma_data = 0.5 # Default: 0.5 + scaled_timestep = timestep * self.timestep_scaling + + c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. 
get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Compute the predicted original sample x_0 based on the model parameterization + if self.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + elif self.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. Clip or threshold "predicted x_0" + if self.thresholding: + predicted_original_sample = self._threshold_sample(predicted_original_sample) + elif self.clip_sample: + predicted_original_sample = predicted_original_sample.clamp(-self.clip_sample_range, self.clip_sample_range) + + # 6. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. 
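+            # i.e. prev_sample = sqrt(alpha_bar_prev) * denoised + sqrt(1 - alpha_bar_prev) * z,
+            # re-noising the denoised estimate to the previous timestep's noise level.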
+ if self.step_index != self.num_inference_steps - 1: + noise = torch.randn( + model_output.shape, device=model_output.device, dtype=denoised.dtype, generator=generator + ) + prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + else: + prev_sample = denoised + + # upon completion increase step index by one + self._step_index += 1 + + return (prev_sample,) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def configure(self): + pass + + def __len__(self): + return self.num_train_timesteps diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index ace75bfbae7cb..fac72be346b3d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -77,6 +77,12 @@ def teardown(self): self.engines = {} def get_cached_model_name(self, model_name): + # TODO(tianleiwu): save custom model to a directory named by its original model. + if model_name == "unetxl" and self.pipeline_info.custom_unet(): + model_name = "lcm_" + model_name + + # TODO: When we support original VAE, we shall save custom VAE to another directory. + if self.pipeline_info.is_inpaint(): model_name += "_inpaint" return model_name @@ -93,6 +99,7 @@ def get_engine_path(self, engine_dir, model_name, profile_id): def load_models(self, framework_model_dir: str): # Disable torch SDPA since torch 2.0.* cannot export it to ONNX + # TODO(tianleiwu): Test and remove it if this is not needed in Torch 2.1. 
if hasattr(torch.nn.functional, "scaled_dot_product_attention"): delattr(torch.nn.functional, "scaled_dot_product_attention") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py index faa3f8bfaabf1..31ede1ba901f2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -68,6 +68,7 @@ def _infer( image_height, image_width, denoising_steps=30, + strength=0.3, guidance=5.0, seed=None, warmup=False, @@ -79,7 +80,6 @@ def _infer( crops_coords_top_left = (0, 0) target_size = (image_height, image_width) - strength = 0.3 aesthetic_score = 6.0 negative_aesthetic_score = 2.5 @@ -155,12 +155,12 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: print("SD-XL Refiner Pipeline") - self.print_summary(e2e_tic, e2e_toc, batch_size) - self.save_images(images, "img2img-xl", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, @@ -171,6 +171,7 @@ def run( image_width, denoising_steps=30, guidance=5.0, + strength=0.3, seed=None, warmup=False, return_type="image", @@ -213,6 +214,7 @@ def run( image_height, image_width, denoising_steps=denoising_steps, + strength=strength, guidance=guidance, seed=seed, warmup=warmup, @@ -226,6 +228,7 @@ def run( image_height, image_width, denoising_steps=denoising_steps, + strength=strength, guidance=guidance, seed=seed, warmup=warmup, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index e675c9a7b3bf5..a0b3c3a1c85b1 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -23,12 +23,13 @@ import os import pathlib import random +from typing import Any, Dict, List import nvtx import torch from cuda import cudart from diffusion_models import PipelineInfo, get_tokenizer -from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler +from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler from engine_builder import EngineType from engine_builder_ort_cuda import OrtCudaEngineBuilder from engine_builder_ort_trt import OrtTensorrtEngineBuilder @@ -63,7 +64,7 @@ def __init__( max_batch_size (int): Maximum batch size for dynamic batch engine. scheduler (str): - The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC]. + The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM]. device (str): PyTorch device to run inference. 
Default: 'cuda' output_dir (str): @@ -162,9 +163,11 @@ def set_scheduler(self, scheduler: str): elif scheduler == "EulerA": self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) elif scheduler == "UniPC": - self.scheduler = UniPCMultistepScheduler(device=self.device) + self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts) + elif scheduler == "LCM": + self.scheduler = LCMScheduler(device=self.device, **sched_opts) else: - raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") + raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM") self.current_scheduler = scheduler self.denoising_steps = None @@ -238,6 +241,7 @@ def encode_prompt( pooled_outputs=False, output_hidden_states=False, force_zeros_for_empty_prompt=False, + do_classifier_free_guidance=True, ): if tokenizer is None: tokenizer = self.tokenizer @@ -265,41 +269,44 @@ def encode_prompt( if output_hidden_states: hidden_states = outputs["hidden_states"].clone() - # Note: negative prompt embedding is not needed for SD XL when guidance < 1 - - # For SD XL base, handle force_zeros_for_empty_prompt - is_empty_negative_prompt = all([not i for i in negative_prompt]) - if force_zeros_for_empty_prompt and is_empty_negative_prompt: - uncond_embeddings = torch.zeros_like(text_embeddings) - if output_hidden_states: - uncond_hidden_states = torch.zeros_like(hidden_states) - else: - # Tokenize negative prompt - uncond_input_ids = ( - tokenizer( - negative_prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", + # Note: negative prompt embedding is not needed for SD XL when guidance <= 1 + if do_classifier_free_guidance: + # For SD XL base, handle force_zeros_for_empty_prompt + is_empty_negative_prompt = all([not i for i in negative_prompt]) + if force_zeros_for_empty_prompt and is_empty_negative_prompt: + uncond_embeddings = torch.zeros_like(text_embeddings) + if output_hidden_states: + uncond_hidden_states = torch.zeros_like(hidden_states) + else: + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) ) - .input_ids.type(torch.int32) - .to(self.device) - ) - outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) - uncond_embeddings = outputs["text_embeddings"] - if output_hidden_states: - uncond_hidden_states = outputs["hidden_states"] + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) if pooled_outputs: pooled_output = text_embeddings if output_hidden_states: - text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + if do_classifier_free_guidance: + text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + else: + text_embeddings = 
hidden_states.to(dtype=torch.float16) cudart.cudaEventRecord(self.events["clip-stop"], 0) if self.nvtx_profile: @@ -321,7 +328,7 @@ def denoise_latent( guidance=7.5, add_kwargs=None, ): - assert guidance > 1.0, "Guidance has to be > 1.0" # TODO: remove this constraint + do_classifier_free_guidance = guidance > 1.0 cudart.cudaEventRecord(self.events["denoise-start"], 0) if not isinstance(timesteps, torch.Tensor): @@ -332,7 +339,7 @@ def denoise_latent( nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input( latent_model_input, step_offset + step_index, timestep @@ -366,11 +373,14 @@ def denoise_latent( nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") # perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) if type(self.scheduler) == UniPCMultistepScheduler: latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + elif type(self.scheduler) == LCMScheduler: + latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0] else: latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) @@ -406,38 +416,42 @@ def decode_latent(self, latents): nvtx.end_range(nvtx_vae) return images - def print_summary(self, tic, toc, batch_size, vae_enc=False): + def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: + throughput = batch_size / (toc - tic) + latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] + latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1] + latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] + latency_vae_encoder = ( + cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1] + if vae_enc + else None + ) + latency = (toc - tic) * 1000.0 + print("|------------|--------------|") print("| {:^10} | {:^12} |".format("Module", "Latency")) print("|------------|--------------|") if vae_enc: - print( - "| {:^10} | {:>9.2f} ms |".format( - "VAE-Enc", - cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1], - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "CLIP", cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "UNet x " + str(self.actual_steps), - cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] - ) - ) + print("| {:^10} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder)) + print("| {:^10} | {:>9.2f} ms |".format("CLIP", latency_clip)) + print("| {:^10} | {:>9.2f} ms |".format("UNet x " + str(self.actual_steps), latency_unet)) + print("| {:^10} | {:>9.2f} ms |".format("VAE-Dec", latency_vae)) 
print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) + print("| {:^10} | {:>9.2f} ms |".format("Pipeline", latency)) print("|------------|--------------|") - print(f"Throughput: {batch_size / (toc - tic):.2f} image/s") + print(f"Throughput: {throughput:.2f} image/s") + + perf_data = { + "latency_clip": latency_clip, + "latency_unet": latency_unet, + "latency_vae": latency_vae, + "latency": latency, + "throughput": throughput, + } + if vae_enc: + perf_data["latency_vae_encoder"] = latency_vae_encoder + return perf_data @staticmethod def to_pil_image(images): @@ -449,26 +463,31 @@ def to_pil_image(images): return [Image.fromarray(images[i]) for i in range(images.shape[0])] - def save_images(self, images, pipeline, prompt): - image_name_prefix = ( - pipeline + "".join(set(["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))])) + "-" - ) + def metadata(self) -> Dict[str, Any]: + return { + "actual_steps": self.actual_steps, + "seed": self.get_current_seed(), + "name": self.pipeline_info.name(), + "custom_vae": self.pipeline_info.custom_fp16_vae(), + "custom_unet": self.pipeline_info.custom_unet(), + } + def save_images(self, images: List, prompt: List[str], negative_prompt: List[str], metadata: Dict[str, Any]): images = self.to_pil_image(images) - random_session_id = str(random.randint(1000, 9999)) + session_id = str(random.randint(1000, 9999)) for i, image in enumerate(images): seed = str(self.get_current_seed()) - image_path = os.path.join( - self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + "-" + seed + ".png" - ) + prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20] + parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)] + image_path = os.path.join(self.output_dir, "-".join(parts) + ".png") print(f"Saving image {i+1} / {len(images)} to: {image_path}") from PIL import PngImagePlugin - metadata = PngImagePlugin.PngInfo() - metadata.add_text("prompt", prompt[i]) - metadata.add_text("batch_size", str(len(images))) - metadata.add_text("denoising_steps", str(self.denoising_steps)) - metadata.add_text("actual_steps", str(self.actual_steps)) - metadata.add_text("seed", seed) - image.save(image_path, "PNG", pnginfo=metadata) + info = PngImagePlugin.PngInfo() + for k, v in metadata.items(): + info.add_text(k, str(v)) + info.add_text("prompt", prompt[i]) + info.add_text("negative_prompt", negative_prompt[i]) + + image.save(image_path, "PNG", pnginfo=info) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py index b9759b44e7635..87ce85af247a5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -84,11 +84,11 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: - self.print_summary(e2e_tic, e2e_toc, batch_size) - self.save_images(images, "txt2img", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py index 
1b3be143e6ce7..8ed7e20e94c07 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -62,7 +62,7 @@ def _infer( return_type="image", ): assert len(prompt) == len(negative_prompt) - + do_classifier_free_guidance = guidance > 1.0 original_size = (image_height, image_width) crops_coords_top_left = (0, 0) target_size = (image_height, image_width) @@ -91,6 +91,7 @@ def _infer( tokenizer=self.tokenizer, output_hidden_states=True, force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, ) # CLIP text encoder 2 text_embeddings2, pooled_embeddings2 = self.encode_prompt( @@ -101,6 +102,7 @@ def _infer( pooled_outputs=True, output_hidden_states=True, force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, ) # Merged text embeddings @@ -111,9 +113,10 @@ def _infer( original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype ) add_time_ids = add_time_ids.repeat(batch_size, 1) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0).to(self.device) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids.to(self.device)} # UNet denoiser latents = self.denoise_latent( @@ -133,13 +136,12 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: print("SD-XL Base Pipeline") - self.print_summary(e2e_tic, e2e_toc, batch_size) - if return_type != "latent": - self.save_images(images, "txt2img-xl", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index a00e25ddd983f..63fa8acfbcc95 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -1,8 +1,8 @@ -diffusers==0.19.3 -transformers==4.31.0 +diffusers==0.23.1 +transformers==4.35.1 numpy>=1.24.1 accelerate -onnx==1.14.0 +onnx==1.14.1 coloredlogs packaging # Use newer version of protobuf might cause crash diff --git a/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc b/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc index ac213f70b1272..1c6721fed05a2 100644 --- a/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc +++ b/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc @@ -49,7 +49,9 @@ TEST(KernelTypeStrResolverUtilsTest, VerifyLayoutTransformationRequiredOpsResolv #endif // !defined(DISABLE_CONTRIB_OPS) } -// run this test manually to output a hard-coded byte array +// run this test manually to output a hard-coded byte array. 
+// update AddLayoutTransformationRequiredOpsToKernelTypeStrResolver in +// onnxruntime/core/framework/kernel_type_str_resolver_utils.cc TEST(KernelTypeStrResolverUtilsTest, DISABLED_PrintExpectedLayoutTransformationRequiredOpsResolverByteArray) { #if defined(DISABLE_CONTRIB_OPS) FAIL() << "Contrib ops must be enabled."; diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 2008d96539dca..5cb4633dadd46 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -466,11 +466,11 @@ GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_cont } template -GetQDQTestCaseFn BuildQDQSplitTestCase( - const std::vector& input_shape, - const int64_t& axis, - bool use_contrib_qdq = false) { - return [input_shape, axis, use_contrib_qdq](ModelTestBuilder& builder) { +GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, + const int64_t& axis, + bool use_diff_output_scale, + bool use_contrib_qdq = false) { + return [input_shape, axis, use_diff_output_scale, use_contrib_qdq](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), std::numeric_limits::max()); @@ -478,16 +478,30 @@ GetQDQTestCaseFn BuildQDQSplitTestCase( InputType dq_zp = std::numeric_limits::max() / 2; OutputType q_zp = std::numeric_limits::max() / 2; auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input_arg, .003f, dq_zp, dq_output, use_contrib_qdq); + constexpr float input_scale = 0.003f; + builder.AddDequantizeLinearNode(input_arg, input_scale, dq_zp, dq_output, use_contrib_qdq); // add Split + std::vector split_inputs; + split_inputs.push_back(dq_output); + + // Use the optional 'split' input when testing Split 13 + int opset = builder.DomainToVersionMap().find(kOnnxDomain)->second; + if (opset >= 13 && opset < 18) { + int64_t dim = input_shape[axis]; + int64_t split_size = dim / 3; + split_inputs.push_back(builder.Make1DInitializer(std::vector{split_size, + split_size, dim - (2 * split_size)})); + } auto* split_output_1 = builder.MakeIntermediate(); auto* split_output_2 = builder.MakeIntermediate(); auto* split_output_3 = builder.MakeIntermediate(); - Node& split_node = builder.AddNode("Split", {dq_output}, {split_output_1, split_output_2, split_output_3}); + Node& split_node = builder.AddNode("Split", split_inputs, {split_output_1, split_output_2, split_output_3}); split_node.AddAttribute("axis", axis); - if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) { + + // Use the 'num_outputs' attribute when testing Split >= 18 + if (opset >= 18) { split_node.AddAttribute("num_outputs", static_cast(3)); } @@ -495,11 +509,12 @@ GetQDQTestCaseFn BuildQDQSplitTestCase( auto* q_split_output_1 = builder.MakeOutput(); auto* q_split_output_2 = builder.MakeOutput(); auto* q_split_output_3 = builder.MakeOutput(); - builder.AddQuantizeLinearNode(split_output_1, .003f, q_zp, q_split_output_1, + float output_scale = use_diff_output_scale ? 
input_scale + 0.001f : input_scale; + builder.AddQuantizeLinearNode(split_output_1, output_scale, q_zp, q_split_output_1, use_contrib_qdq); // Model input (node_token_1) - builder.AddQuantizeLinearNode(split_output_2, .003f, q_zp, q_split_output_2, + builder.AddQuantizeLinearNode(split_output_2, output_scale, q_zp, q_split_output_2, use_contrib_qdq); // Model input (node_token_2) - builder.AddQuantizeLinearNode(split_output_3, .003f, q_zp, q_split_output_3, + builder.AddQuantizeLinearNode(split_output_3, output_scale, q_zp, q_split_output_3, use_contrib_qdq); }; } @@ -549,13 +564,30 @@ GetQDQTestCaseFn BuildQDQTransposeTestCase( InputType dq_zp = std::numeric_limits::max() / 2; OutputType q_zp = std::numeric_limits::max() / 2; - // add DQ - auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input_arg, .003f, dq_zp, dq_output, use_contrib_qdq); + // In order to test additional EPs that are more sensitive to whether the Transpose is in a QDQ node unit or not, + // we need a QDQ node unit prior to DQ -> Transpose -> Q -> graph output. + // The transpose optimizer will push the transpose, convert its input to uint8, and drop the empty DQ -> Q. + // If there's a QDQ node unit prior, the scale and zp info can be read from the Q node feeding the standalone + // Transpose node, so we add a DQ -> Mul -> Q to provide that. + // Essentially everything has worked correctly if the DQ -> Transpose -> Q becomes a single Transpose and the + // extra QDQ node unit simply allows some additional functionality to be tested. + + // add DQ -> Mul -> Q + auto* dq_output_0 = builder.MakeIntermediate(); + auto* mul_output = builder.MakeIntermediate(); + auto* q_output_0 = builder.MakeIntermediate(); + auto mul_by = builder.MakeInitializer({1}, 2.f, 3.f); + builder.AddDequantizeLinearNode(input_arg, .003f, dq_zp, dq_output_0, use_contrib_qdq); + builder.AddNode("Mul", {dq_output_0, mul_by}, {mul_output}); + builder.AddQuantizeLinearNode(mul_output, .003f, q_zp, q_output_0, use_contrib_qdq); + + // add DQ -> Transpose -> Q + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(q_output_0, .003f, dq_zp, dq_output_1, use_contrib_qdq); // add Transpose auto* transpose_output = builder.MakeIntermediate(); - Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {transpose_output}); + Node& transpose_node = builder.AddNode("Transpose", {dq_output_1}, {transpose_output}); transpose_node.AddAttribute("perm", perms); // add Q diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1bf1cbacf479e..6b0f837c14b5a 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -1187,21 +1187,32 @@ static void RunDoubleQDQWithoutLastNodeBeingOutput(int output_index, int expecte TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { constexpr bool use_contrib_qdq = true; // For readability.
- RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2); - RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2, use_contrib_qdq); - RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2, use_contrib_qdq); - RunDoubleQDQWithoutLastNodeBeingOutput(0, 2, 2, use_contrib_qdq); - - // EnsureUniqueDQForNodeUnit will duplicate first DQ, so expected one more (3) - RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3); - RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3, use_contrib_qdq); - RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3, use_contrib_qdq); - RunDoubleQDQWithoutLastNodeBeingOutput(1, 2, 3, use_contrib_qdq); + // the first node being a graph output doesn't prevent the DQ -> Q in the middle from being removed + // if they have matching type/scale/zp + // Q -> DQ -> Q -> DQ + // `-> graph output + RunDoubleQDQWithoutLastNodeBeingOutput(0, 1, 1); + RunDoubleQDQWithoutLastNodeBeingOutput(0, 1, 1, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(0, 1, 1, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(0, 1, 1, use_contrib_qdq); + + // EnsureUniqueDQForNodeUnit will duplicate first DQ, but after that the DQ -> Q in the middle can still be removed + // leaving one Q and 2 DQ. + // Q -> DQ -> Q -> DQ + // `-> graph output + // => + // Q -> DQ -> Q -> DQ + // `-> DQ -> graph output + RunDoubleQDQWithoutLastNodeBeingOutput(1, 1, 2); + RunDoubleQDQWithoutLastNodeBeingOutput(1, 1, 2, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(1, 1, 2, use_contrib_qdq); + RunDoubleQDQWithoutLastNodeBeingOutput(1, 1, 2, use_contrib_qdq); RunDoubleQDQWithoutLastNodeBeingOutput(2, 2, 2); RunDoubleQDQWithoutLastNodeBeingOutput(2, 2, 2, use_contrib_qdq); RunDoubleQDQWithoutLastNodeBeingOutput(2, 2, 2, use_contrib_qdq); + // last node being a graph output doesn't prevent the DQ -> Q in the middle from being removed RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1); RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1, use_contrib_qdq); RunDoubleQDQWithoutLastNodeBeingOutput(3, 1, 1, use_contrib_qdq); @@ -1210,27 +1221,51 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { // Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split. template static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis, - bool use_contrib_qdq = false) { - auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + bool all_same_quant_params, bool use_contrib_qdq = false) { + auto check_graph = [all_same_quant_params, use_contrib_qdq](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + int expected_q_ops = all_same_quant_params ? 0 : 3; + int expected_dq_ops = all_same_quant_params ?
0 : 1; EXPECT_EQ(op_to_count["Split"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_q_ops); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_dq_ops); }; - TransformerTester(BuildQDQSplitTestCase(input_shape, axis, use_contrib_qdq), + TransformerTester(BuildQDQSplitTestCase(input_shape, axis, !all_same_quant_params, + use_contrib_qdq), check_graph, TransformerLevel::Level1, TransformerLevel::Level2, - {12, 18, 19}); + {12, 13, 18, 19}); // Test different ways to specify the split in each opset: + // 12 - split into equal parts without explicit 'split' attribute + // 13 - use optional 'split' input to split into 3 parts + // 18 - use 'num_outputs' attribute to split into 3 parts + // 19 - use 'num_outputs' attribute to split into 3 parts } // Test that DQ -> Split -> Q (many) is replaced with just Split for various quantization types. TEST(QDQTransformerTests, Split) { - RunDropSplitQDQTestCase({6, 18, 54}, 0); - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int8 QDQ ops - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int16 QDQ ops - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft uint16 QDQ ops + // Test cases that drop Q/DQ ops from DQ -> Split -> Q (many). + // This happens when all the Q/DQ ops have equal and constant quantization parameters. + { + constexpr bool ALL_SAME_QUANT_PARAMS = true; + constexpr bool USE_CONTRIB_QDQ_OPS = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + } + + // Test cases that DO NOT drop Q/DQ ops from DQ -> Split -> Q (many) + // This happens when the Q/DQ ops do not have equal and constant quantization parameters. 
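To make the comment above concrete, a DQ followed by a Q is only a no-op (and thus droppable) when both sides use the same scale and zero point; with the differing output scale used by `use_diff_output_scale`, the round trip changes the data, so the Q/DQ ops around the Split must stay. A simplified numpy sketch of the uint8 affine quantization math (illustrative values, not taken from the test):

```python
import numpy as np

def quantize(x, scale, zp):
    # uint8 affine quantization: q = clip(round(x / scale) + zp, 0, 255)
    return np.clip(np.round(x / scale) + zp, 0, 255).astype(np.uint8)

def dequantize(q, scale, zp):
    return (q.astype(np.float32) - zp) * scale

q_in = np.array([0, 37, 128, 255], dtype=np.uint8)

# Same scale/zero point on both sides: DQ -> Q reproduces the input exactly,
# so the pair can be removed from the graph.
same = quantize(dequantize(q_in, 0.003, 127), 0.003, 127)
print(np.array_equal(same, q_in))  # True

# Different output scale (e.g. input_scale + 0.001): values change, so the
# Q/DQ ops have to be kept.
diff = quantize(dequantize(q_in, 0.003, 127), 0.004, 127)
print(np.array_equal(diff, q_in))  # False
```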
+ { + constexpr bool DIFF_QUANT_PARAMS = false; + constexpr bool USE_CONTRIB_QDQ_OPS = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + } } // Because split isn't one the supported ops, this will stay the same @@ -1296,12 +1331,15 @@ TEST(QDQTransformerTests, Where) { template static void RunDropQDQTransposeTestCase(const std::vector& input_shape, const std::vector& perms, bool use_contrib_qdq = false) { + // model has DQ -> Mul -> Q -> DQ -> Transpose -> Q -> output + // post transform and optimization it should be DQ -> Mul -> Q -> Transpose(uint8) -> output auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); EXPECT_EQ(op_to_count["Transpose"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["Mul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); }; TransformerTester(BuildQDQTransposeTestCase(input_shape, perms, use_contrib_qdq), @@ -3068,29 +3106,54 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { transpose_node.AddAttribute("perm", perms); }; + bool use_transpose_optimizer = false; + auto check_graph = [&](InferenceSessionWrapper& session) { - // transpose optimization will change the order of the nodes, - // but as we're testing there's no propagation of the DQ what matters is the op counts. - auto op_counts = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - EXPECT_EQ(op_counts[qdq_keys.dequantize_linear], 1); - EXPECT_EQ(op_counts["Transpose"], 1); + + // if the transpose optimizer isn't used the DQ doesn't propagate past the Transpose + // TODO: Should it? It makes it easier for an EP to do a quantized Transpose if it's in a QDQ node unit as it + // doesn't have to special-case looking for a solo Transpose. + std::vector expected_op_types_in_order{qdq_keys.dequantize_linear, + "Transpose"}; + if (use_transpose_optimizer) { + // fixup of QDQ node units would have put the Transpose in a QDQ node unit for consistency IFF + // the scale and zero point inputs are constant (which they are here) + expected_op_types_in_order.push_back(qdq_keys.quantize_linear); + expected_op_types_in_order.push_back(qdq_keys.dequantize_linear); + } + + const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + + if (use_transpose_optimizer) { + // the trailing Q/DQ should have updated axis based on the transpose.
default axis of 1 moves to 3 with + // transpose of {0,2,3,1} (NCHW -> NHWC) + GraphViewer graph_viewer{session.GetGraph()}; + const auto& ordered_nodes = graph_viewer.GetNodesInTopologicalOrder(); + const auto& q_node = *graph_viewer.GetNode(ordered_nodes.back() - 1); + const auto& dq_node = *graph_viewer.GetNode(ordered_nodes.back()); + + EXPECT_EQ(graph_utils::GetNodeAttribute(q_node, std::string("axis"))->i(), 3); + EXPECT_EQ(graph_utils::GetNodeAttribute(dq_node, std::string("axis"))->i(), 3); + } }; - TransformerTester(build_test_case, - check_graph, - TransformerLevel::Default, - TransformerLevel::Level1); - TransformerTester(build_test_case, - check_graph, - TransformerLevel::Default, - TransformerLevel::Level1, - 18); - TransformerTester(build_test_case, - check_graph, - TransformerLevel::Default, - TransformerLevel::Level1, - 19); + auto run_test = [&](int opset) { + use_transpose_optimizer = true; + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset); + + use_transpose_optimizer = false; + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset, + // defaults that we're not overriding + 0.0, 0.0, nullptr, {}, + // disable generic L1 and CPU EP specific L2 TransposeOptimizer + {"TransposeOptimizer", std::string("TransposeOptimizer_") + kCpuExecutionProvider}); + }; + + run_test(12); + run_test(18); + run_test(19); }; test_case({1, 13, 13, 23}, {0, 2, 3, 1}, false /*use_contrib_qdq*/); @@ -3293,10 +3356,9 @@ TEST(QDQTransformerTests, QDQPropagation_GH11605_Opset12_19) { // Original: DQ -> Tr -> SoftM -> Tr // QDQ Prop inserts a Q/DQ pair to create a QDQ node group for the Transpose: DQ -> Tr -> Q -> DQ -> SoftM -> Tr // Transpose opt phase 1 moves the Tr down until it blocks on the SoftMax: DQ -> Q -> DQ -> Tr -> SoftM -> Tr - // Transpose opt phase 2 flips the Tr to prior to the DQ as it's not part of a QDQ node group at that point, as - // running the transpose on 8-bit data should be cheaper: DQ -> Q -> Tr -> DQ -> SoftM -> Tr - // QDQ cleanup in Level2 removes the unnecessary DQ/Q pair at the start: Tr -> DQ -> SoftM -> Tr - // this is the optimal result as the Transpose is using 8-bit data and we have no surplus Q/DQ pairs + // Transpose opt phase 2 repairs the QDQ node units: DQ -> Q -> DQ -> Tr -> Q -> DQ -> SoftM -> TR + // and removes the unnecessary DQ/Q pair at the start: DQ -> Tr -> Q -> DQ -> SoftM -> Tr + // The L2 CPU EP QDQ handling converts the DQ -> Tr -> Q to a Transpose with 8-bit data. 
auto check_graph = [&](InferenceSessionWrapper& session) { const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); std::vector expected_op_types_in_order{ @@ -3305,8 +3367,13 @@ TEST(QDQTransformerTests, QDQPropagation_GH11605_Opset12_19) { "Softmax", "Transpose"}; - const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph(), true); + const auto& graph = session.GetGraph(); + GraphViewer graph_viewer(graph); + const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(graph, true); EXPECT_EQ(op_types_in_order, expected_op_types_in_order); + + auto first_node = graph_viewer.GetNode(graph_viewer.GetNodesInTopologicalOrder().front()); + EXPECT_EQ(*first_node->InputDefs()[0]->Type(), "tensor(uint8)"); }; TransformerTester(build_test_case, diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 4f4157bd7b1cf..a1649f9e6b588 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -3742,66 +3742,6 @@ TEST(TransposeOptimizerTests, TestDequantizeLinearNoAxis) { #endif } -// Utility function that runs TransformerTester for the graph in which a single DequantizeLinear node is -// the parent of two Transpose nodes. The DQ should be duplicated by EnsureUniqueDQForNodeUnit, and the -// Transposes should be pushed. -template -static void RunDequantizeLinearTransposePropagationTestCase(const std::string& dq_domain = "") { - auto build_test_case = [dq_domain](ModelTestBuilder& builder) { - auto* input0_arg = MakeInput(builder, {{2, -1, 6, 3}}, {2, 4, 6, 3}, 0, 5); - auto* scale_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {2.3f}); - auto* zero_point_arg = MakeInput(builder, {std::vector{}}, std::vector{}, {10}); - auto* dequantizelinear_1_out_0 = builder.MakeIntermediate(); - auto* transpose_1_out_0 = builder.MakeOutput(); - auto* transpose_2_out_0 = builder.MakeOutput(); - - builder.AddNode("DequantizeLinear", {input0_arg, scale_arg, zero_point_arg}, {dequantizelinear_1_out_0}, - dq_domain); - - auto& transpose_1 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_1_out_0}); - transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); - - auto& transpose_2 = builder.AddNode("Transpose", {dequantizelinear_1_out_0}, {transpose_2_out_0}); - transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); - }; - - auto check_graph = [dq_domain](InferenceSessionWrapper& session) { - const auto& graph = session.GetGraph(); - - const char* dq_count_key = (dq_domain == kMSDomain) ? 
"com.microsoft.DequantizeLinear" : "DequantizeLinear"; - const auto op_count = CountOpsInGraph(graph); - decltype(op_count) expected_op_count{ - {dq_count_key, 2}, // EnsureUniqueDQForNodeUnit should duplicate the original DQ - {"Transpose", 2}, - }; - ASSERT_EQ(op_count, expected_op_count); - - // Transposes should be pushed, so check for Transpose -> DQ edges - for (const auto& node : graph.Nodes()) { - if (node.OpType() == "Transpose") { - ASSERT_EQ(node.GetOutputEdgesCount(), static_cast(1)); - ASSERT_EQ(node.OutputEdgesBegin()->GetNode().OpType(), "DequantizeLinear"); - } - } - }; - - TransformerTester(build_test_case, - check_graph, - TransformerLevel::Default, - TransformerLevel::Level1, - /*opset_version*/ 10); -} - -TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) { - RunDequantizeLinearTransposePropagationTestCase(); -#if !defined(DISABLE_CONTRIB_OPS) - // Use com.microsoft.DequantizeLinear - RunDequantizeLinearTransposePropagationTestCase(kMSDomain); - RunDequantizeLinearTransposePropagationTestCase(kMSDomain); - RunDequantizeLinearTransposePropagationTestCase(kMSDomain); -#endif -} - TEST(TransposeOptimizerTests, TestCast) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{-1, 4, -1, 5}}, {2, 4, 6, 5}, -1, 5); @@ -4609,7 +4549,6 @@ static void CheckSharedInitializerHandling(bool broadcast) { std::vector fetches; SessionOptions so; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); // get results with no modifications to the model @@ -4641,11 +4580,16 @@ static void CheckSharedInitializerHandling(bool broadcast) { ASSERT_EQ(result.error_msg, std::nullopt); ASSERT_TRUE(result.graph_modified); ASSERT_TRUE(graph.GraphResolveNeeded()); + ASSERT_STATUS_OK(graph.Resolve()); - std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Transpose"], 0) << "The Transpose nodes should have been pushed through and canceled out."; + // Use this hack to save model for viewing if needed + // ASSERT_STATUS_OK(Model::Save(const_cast(session.GetModel()), "updated_model.onnx")); - ASSERT_STATUS_OK(graph.Resolve()); + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0) << "The Transpose nodes should have been pushed through or canceled out."; + if (broadcast) { + EXPECT_EQ(op_to_count["Unsqueeze"], 0) << "Any Unsqueeze nodes should have been canceled out."; + } ASSERT_STATUS_OK(session.Initialize()); ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); @@ -4671,5 +4615,142 @@ TEST(TransposeOptimizerTests, SharedInitializerHandling) { TEST(TransposeOptimizerTests, SharedInitializerHandlingBroadcast) { CheckSharedInitializerHandling(/*broadcast*/ true); } + +// Unit test where EstimateTransposeValueCost must look past a DQ -> Squeeze to see the Transponse of a shared +// initializer for the overall cost of pushing the Transpose throught the second Where to be negative. 
+TEST(TransposeOptimizerTests, SharedInitializerHandlingBroadcast2) { + auto model_uri = ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast2.onnx"); + + RandomValueGenerator random{123}; + std::vector cond_input_0_dims{3, 2}; + std::vector cond_input_1_dims{2, 3}; + std::vector cond_input_data = {true, false, false, true, true, false}; + + std::vector x_0_input_dims{3}; + std::vector x_1_input_dims{3}; + std::vector x_input_data_0 = random.Gaussian(x_0_input_dims, 0.0f, 1.0f); + std::vector x_input_data_1 = random.Gaussian(x_1_input_dims, 0.0f, 1.0f); + + OrtValue cond_input_0, cond_input_1, x_input_0, x_input_1; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], cond_input_0_dims, cond_input_data, + &cond_input_0); + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], cond_input_1_dims, cond_input_data, + &cond_input_1); + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], x_0_input_dims, x_input_data_0, + &x_input_0); + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], x_1_input_dims, x_input_data_1, + &x_input_1); + + NameMLValMap feeds{{"cond_in_0", cond_input_0}, + {"cond_in_1", cond_input_1}, + {"x_in_0", x_input_0}, + {"x_in_1", x_input_1}}; + + std::vector output_names{"output0"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + + // get results with no modifications to the model + { + so.graph_optimization_level = TransformerLevel::Default; // off + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + // we call the ONNX transpose optimizer directly to simplify the model required to exercise the shared initializer + // handling. this means we don't need to disable optimizers that might alter the graph before the + // transpose optimizer runs (at a minimum ConstantFolding, CommonSubexpressionElimination and ConstantSharing). 
+ Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + using namespace onnx_transpose_optimization; + auto api_graph = MakeApiGraph(graph, TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + + // default optimization cost check + OptimizeResult result = Optimize(*api_graph); + + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + ASSERT_STATUS_OK(graph.Resolve()); + + // Use this hack to save model for viewing if needed + // ASSERT_STATUS_OK(Model::Save(const_cast(session.GetModel()), "updated_model.onnx")); + + // Pushing the initial Transpose through the 2 Where nodes results in + // - x_in_0 needs Transpose and Unsqueeze to broadcast correctly into the first Where + // - y_quant is updated in-place to transposed layout and used in both Where nodes + // - x_in_1 needs Transpose and Unsqueeze to broadcast correctly into the second Where + // - cond_in_1 needs Transpose + // - as we're pushing a Transpose through the Add for one input, and undoing the Transpose on y_quant for + // the other input, we save 2 by adding 1 to cond_in_1 + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 3) << "The 2 X inputs and cond_in_1 should require transpose."; + EXPECT_EQ(op_to_count["Unsqueeze"], 2) << "The 2 X inputs should require Unsqueeze."; + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} + +// model where layout transform results in transposing a non-const input that is broadcast. +// this inserts Unsqueeze -> Transpose between the input and the node.
+// test that QDQ node units are created for Unsqueeze and Transpose by inserting Q->DQ pairs after them +TEST(TransposeOptimizerTests, QnnTransposeNonConstBroadcastInput) { + Status status; + auto model_uri = ORT_TSTR("testdata/layout_transform_nonconst_broadcast_input.onnx"); + + SessionOptions so; + + // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + + using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + + // set the test EP to support all ops in the model so that the layout transform applies to all nodes + const std::unordered_set empty_set; + auto internal_testing_ep = std::make_unique(empty_set, empty_set, DataLayout::NHWC); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep))); + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Transpose"], 3) << "Should have Transpose on 2 inputs and one on output."; + + // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout + std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + for (const auto& node : graph.Nodes()) { + EXPECT_EQ(node.GetExecutionProviderType(), expected_ep) << node.OpType() << " node named '" << node.Name() + << "' was not assigned to the internal testing EP."; + // all nodes should be in QDQ node units except the Cast on an input which was not in a QDQ unit + if (node.OpType() != "QuantizeLinear" && node.OpType() != "DequantizeLinear" && node.OpType() != "Cast") { + for (auto cur_input = node.InputNodesBegin(), end = node.InputNodesEnd(); cur_input != end; ++cur_input) { + EXPECT_EQ(cur_input->OpType(), "DequantizeLinear"); + } + + for (auto cur_output = node.OutputNodesBegin(), end = node.OutputNodesEnd(); cur_output != end; ++cur_output) { + EXPECT_EQ(cur_output->OpType(), "QuantizeLinear"); + } + } + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/platform/ios/ios_package_test/.gitignore b/onnxruntime/test/platform/apple/apple_package_test/.gitignore similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/.gitignore rename to onnxruntime/test/platform/apple/apple_package_test/.gitignore diff --git a/onnxruntime/test/platform/ios/ios_package_test/Podfile.template b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template similarity index 52% rename from onnxruntime/test/platform/ios/ios_package_test/Podfile.template rename to onnxruntime/test/platform/apple/apple_package_test/Podfile.template index d2155660d73da..3d191d6fb1cc6 100644 --- a/onnxruntime/test/platform/ios/ios_package_test/Podfile.template +++ b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template @@ -1,14 +1,34 @@ -platform :ios, '13.0' +def include_macos_target + if '@C_POD_NAME@' != 'onnxruntime-mobile-c' + return true + end + return false +end target 'ios_package_test' do # Comment the next line if you don't want to use dynamic frameworks use_frameworks! + platform :ios, '13.0' + target 'ios_package_testUITests' do inherit! 
:search_paths pod '@C_POD_NAME@', :podspec => '@C_POD_PODSPEC@' end +end +if include_macos_target + target 'macos_package_test' do + # Comment the next line if you don't want to use dynamic frameworks + use_frameworks! + + platform :osx, '11.0' + + target 'macos_package_testUITests' do + inherit! :search_paths + pod '@C_POD_NAME@', :podspec => '@C_POD_PODSPEC@' + end + end end # This is to prevent the pods to be code signed if enabled diff --git a/onnxruntime/test/platform/ios/ios_package_test/README.md b/onnxruntime/test/platform/apple/apple_package_test/README.md similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/README.md rename to onnxruntime/test/platform/apple/apple_package_test/README.md diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test.xcodeproj/project.pbxproj b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj similarity index 57% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test.xcodeproj/project.pbxproj rename to onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj index 151db693236f0..66dd772e5e40b 100644 --- a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test.xcodeproj/project.pbxproj +++ b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj @@ -14,6 +14,11 @@ 229E595926586B4A006E41AE /* sigmoid.ort in Resources */ = {isa = PBXBuildFile; fileRef = 229E595826586B4A006E41AE /* sigmoid.ort */; }; 22C1D8EA271A79FD002CEE67 /* ios_package_uitest_cpp_api.mm in Sources */ = {isa = PBXBuildFile; fileRef = 22C1D8E9271A79FD002CEE67 /* ios_package_uitest_cpp_api.mm */; }; 22C1D8EB271A7A06002CEE67 /* sigmoid.ort in Resources */ = {isa = PBXBuildFile; fileRef = 229E595826586B4A006E41AE /* sigmoid.ort */; }; + 51C316BD2B0881450033C70B /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 51C316BC2B0881450033C70B /* AppDelegate.m */; }; + 51C316C52B0881480033C70B /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 51C316C32B0881480033C70B /* Main.storyboard */; }; + 51C316C72B0881480033C70B /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 51C316C62B0881480033C70B /* main.m */; }; + 51C316DC2B0881490033C70B /* macos_package_uitest_cpp_api.mm in Sources */ = {isa = PBXBuildFile; fileRef = 51C316DB2B0881490033C70B /* macos_package_uitest_cpp_api.mm */; }; + 51C316E82B0892EE0033C70B /* sigmoid.ort in Resources */ = {isa = PBXBuildFile; fileRef = 229E595826586B4A006E41AE /* sigmoid.ort */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -24,6 +29,13 @@ remoteGlobalIDString = 229E591B265869BF006E41AE; remoteInfo = ios_package_test; }; + 51C316D82B0881490033C70B /* PBXContainerItemProxy */ = { + isa = PBXContainerItemProxy; + containerPortal = 229E5914265869BF006E41AE /* Project object */; + proxyType = 1; + remoteGlobalIDString = 51C316B82B0881450033C70B; + remoteInfo = macos_package_test; + }; /* End PBXContainerItemProxy section */ /* Begin PBXFileReference section */ @@ -37,6 +49,14 @@ 229E595826586B4A006E41AE /* sigmoid.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = sigmoid.ort; sourceTree = ""; }; 22C1D8DE271A79AF002CEE67 /* ios_package_testUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = ios_package_testUITests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; 22C1D8E9271A79FD002CEE67 /* ios_package_uitest_cpp_api.mm */ = {isa = 
PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_package_uitest_cpp_api.mm; sourceTree = ""; }; + 51C316B92B0881450033C70B /* macos_package_test.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = macos_package_test.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 51C316BB2B0881450033C70B /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 51C316BC2B0881450033C70B /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; + 51C316C42B0881480033C70B /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 51C316C62B0881480033C70B /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; + 51C316C82B0881480033C70B /* macos_package_test.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = macos_package_test.entitlements; sourceTree = ""; }; + 51C316D72B0881490033C70B /* macos_package_testUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = macos_package_testUITests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; + 51C316DB2B0881490033C70B /* macos_package_uitest_cpp_api.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = macos_package_uitest_cpp_api.mm; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -54,6 +74,20 @@ ); runOnlyForDeploymentPostprocessing = 0; }; + 51C316B62B0881450033C70B /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; + 51C316D42B0881490033C70B /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; /* End PBXFrameworksBuildPhase section */ /* Begin PBXGroup section */ @@ -63,7 +97,10 @@ 229E595426586A77006E41AE /* models */, 229E591E265869BF006E41AE /* ios_package_test */, 22C1D8DF271A79AF002CEE67 /* ios_package_testUITests */, + 51C316BA2B0881450033C70B /* macos_package_test */, + 51C316DA2B0881490033C70B /* macos_package_testUITests */, 229E591D265869BF006E41AE /* Products */, + B49FE29C3625E88EDCCDD4BC /* Pods */, ); sourceTree = ""; }; @@ -72,6 +109,8 @@ children = ( 229E591C265869BF006E41AE /* ios_package_test.app */, 22C1D8DE271A79AF002CEE67 /* ios_package_testUITests.xctest */, + 51C316B92B0881450033C70B /* macos_package_test.app */, + 51C316D72B0881490033C70B /* macos_package_testUITests.xctest */, ); name = Products; sourceTree = ""; @@ -105,6 +144,33 @@ path = ios_package_testUITests; sourceTree = ""; }; + 51C316BA2B0881450033C70B /* macos_package_test */ = { + isa = PBXGroup; + children = ( + 51C316BB2B0881450033C70B /* AppDelegate.h */, + 51C316BC2B0881450033C70B /* AppDelegate.m */, + 51C316C32B0881480033C70B /* Main.storyboard */, + 51C316C62B0881480033C70B /* main.m */, + 51C316C82B0881480033C70B /* macos_package_test.entitlements */, + ); + path = macos_package_test; + sourceTree = ""; + }; + 51C316DA2B0881490033C70B /* macos_package_testUITests */ = { + isa = PBXGroup; + children = ( + 51C316DB2B0881490033C70B /* macos_package_uitest_cpp_api.mm */, + ); + path = macos_package_testUITests; + sourceTree = ""; + }; + 
B49FE29C3625E88EDCCDD4BC /* Pods */ = { + isa = PBXGroup; + children = ( + ); + path = Pods; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXNativeTarget section */ @@ -143,6 +209,41 @@ productReference = 22C1D8DE271A79AF002CEE67 /* ios_package_testUITests.xctest */; productType = "com.apple.product-type.bundle.ui-testing"; }; + 51C316B82B0881450033C70B /* macos_package_test */ = { + isa = PBXNativeTarget; + buildConfigurationList = 51C316DF2B0881490033C70B /* Build configuration list for PBXNativeTarget "macos_package_test" */; + buildPhases = ( + 51C316B52B0881450033C70B /* Sources */, + 51C316B62B0881450033C70B /* Frameworks */, + 51C316B72B0881450033C70B /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = macos_package_test; + productName = macos_package_test; + productReference = 51C316B92B0881450033C70B /* macos_package_test.app */; + productType = "com.apple.product-type.application"; + }; + 51C316D62B0881490033C70B /* macos_package_testUITests */ = { + isa = PBXNativeTarget; + buildConfigurationList = 51C316E52B0881490033C70B /* Build configuration list for PBXNativeTarget "macos_package_testUITests" */; + buildPhases = ( + 51C316D32B0881490033C70B /* Sources */, + 51C316D42B0881490033C70B /* Frameworks */, + 51C316D52B0881490033C70B /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + 51C316D92B0881490033C70B /* PBXTargetDependency */, + ); + name = macos_package_testUITests; + productName = macos_package_testUITests; + productReference = 51C316D72B0881490033C70B /* macos_package_testUITests.xctest */; + productType = "com.apple.product-type.bundle.ui-testing"; + }; /* End PBXNativeTarget section */ /* Begin PBXProject section */ @@ -158,9 +259,16 @@ CreatedOnToolsVersion = 13.0; TestTargetID = 229E591B265869BF006E41AE; }; + 51C316B82B0881450033C70B = { + CreatedOnToolsVersion = 15.0.1; + }; + 51C316D62B0881490033C70B = { + CreatedOnToolsVersion = 15.0.1; + TestTargetID = 51C316B82B0881450033C70B; + }; }; }; - buildConfigurationList = 229E5917265869BF006E41AE /* Build configuration list for PBXProject "ios_package_test" */; + buildConfigurationList = 229E5917265869BF006E41AE /* Build configuration list for PBXProject "apple_package_test" */; compatibilityVersion = "Xcode 9.3"; developmentRegion = en; hasScannedForEncodings = 0; @@ -175,6 +283,8 @@ targets = ( 229E591B265869BF006E41AE /* ios_package_test */, 22C1D8DD271A79AF002CEE67 /* ios_package_testUITests */, + 51C316B82B0881450033C70B /* macos_package_test */, + 51C316D62B0881490033C70B /* macos_package_testUITests */, ); }; /* End PBXProject section */ @@ -198,6 +308,22 @@ ); runOnlyForDeploymentPostprocessing = 0; }; + 51C316B72B0881450033C70B /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 51C316C52B0881480033C70B /* Main.storyboard in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + 51C316D52B0881490033C70B /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 51C316E82B0892EE0033C70B /* sigmoid.ort in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; /* End PBXResourcesBuildPhase section */ /* Begin PBXSourcesBuildPhase section */ @@ -218,6 +344,23 @@ ); runOnlyForDeploymentPostprocessing = 0; }; + 51C316B52B0881450033C70B /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 51C316C72B0881480033C70B /* main.m in Sources */, + 51C316BD2B0881450033C70B /* AppDelegate.m in Sources */, + ); + 
runOnlyForDeploymentPostprocessing = 0; + }; + 51C316D32B0881490033C70B /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 51C316DC2B0881490033C70B /* macos_package_uitest_cpp_api.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; /* End PBXSourcesBuildPhase section */ /* Begin PBXTargetDependency section */ @@ -226,6 +369,11 @@ target = 229E591B265869BF006E41AE /* ios_package_test */; targetProxy = 22C1D8E4271A79AF002CEE67 /* PBXContainerItemProxy */; }; + 51C316D92B0881490033C70B /* PBXTargetDependency */ = { + isa = PBXTargetDependency; + target = 51C316B82B0881450033C70B /* macos_package_test */; + targetProxy = 51C316D82B0881490033C70B /* PBXContainerItemProxy */; + }; /* End PBXTargetDependency section */ /* Begin PBXVariantGroup section */ @@ -245,6 +393,14 @@ name = LaunchScreen.storyboard; sourceTree = ""; }; + 51C316C32B0881480033C70B /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 51C316C42B0881480033C70B /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; /* End PBXVariantGroup section */ /* Begin XCBuildConfiguration section */ @@ -300,6 +456,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 13.0; + MACOSX_DEPLOYMENT_TARGET = 11.0; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; @@ -353,6 +510,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 13.0; + MACOSX_DEPLOYMENT_TARGET = 11.0; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; SDKROOT = iphoneos; @@ -365,6 +523,7 @@ buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; INFOPLIST_FILE = ios_package_test/Info.plist; LD_RUNPATH_SEARCH_PATHS = ( @@ -373,7 +532,10 @@ ); PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.ios-package-test"; PRODUCT_NAME = "$(TARGET_NAME)"; - TARGETED_DEVICE_FAMILY = "1,2"; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator"; + SUPPORTS_MACCATALYST = NO; + SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; + TARGETED_DEVICE_FAMILY = 1; }; name = Debug; }; @@ -382,6 +544,7 @@ buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; INFOPLIST_FILE = ios_package_test/Info.plist; LD_RUNPATH_SEARCH_PATHS = ( @@ -390,7 +553,10 @@ ); PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.ios-package-test"; PRODUCT_NAME = "$(TARGET_NAME)"; - TARGETED_DEVICE_FAMILY = "1,2"; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator"; + SUPPORTS_MACCATALYST = NO; + SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; + TARGETED_DEVICE_FAMILY = 1; }; name = Release; }; @@ -398,6 +564,7 @@ isa = XCBuildConfiguration; buildSettings = { CLANG_CXX_LANGUAGE_STANDARD = "gnu++17"; + CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; @@ -420,6 +587,7 @@ isa = XCBuildConfiguration; buildSettings = { CLANG_CXX_LANGUAGE_STANDARD = "gnu++17"; + CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; @@ -438,10 +606,128 @@ }; name = Release; }; + 51C316E02B0881490033C70B /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + 
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CODE_SIGN_ENTITLEMENTS = macos_package_test/macos_package_test.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + COMBINE_HIDPI_IMAGES = YES; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_TEAM = UBF8T346G9; + ENABLE_HARDENED_RUNTIME = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INFOPLIST_KEY_NSMainStoryboardFile = Main; + INFOPLIST_KEY_NSPrincipalClass = NSApplication; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/../Frameworks", + ); + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 11.0; + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-test"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SDKROOT = macosx; + SWIFT_EMIT_LOC_STRINGS = YES; + }; + name = Debug; + }; + 51C316E12B0881490033C70B /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CODE_SIGN_ENTITLEMENTS = macos_package_test/macos_package_test.entitlements; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + COMBINE_HIDPI_IMAGES = YES; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_TEAM = UBF8T346G9; + ENABLE_HARDENED_RUNTIME = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INFOPLIST_KEY_NSMainStoryboardFile = Main; + INFOPLIST_KEY_NSPrincipalClass = NSApplication; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/../Frameworks", + ); + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 11.0; + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-test"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SDKROOT = macosx; + SWIFT_EMIT_LOC_STRINGS = YES; + }; + name = Release; + }; + 51C316E62B0881490033C70B /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_TEAM = UBF8T346G9; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GENERATE_INFOPLIST_FILE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 11.0; + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "com.MS.macos-package-testUITests"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SDKROOT = macosx; + SWIFT_EMIT_LOC_STRINGS = NO; + TEST_TARGET_NAME = macos_package_test; + }; + name = Debug; + }; + 51C316E72B0881490033C70B /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CODE_SIGN_IDENTITY = "Apple Development"; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_TEAM = 
UBF8T346G9; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GENERATE_INFOPLIST_FILE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 11.0; + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "com.MS.macos-package-testUITests"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SDKROOT = macosx; + SWIFT_EMIT_LOC_STRINGS = NO; + TEST_TARGET_NAME = macos_package_test; + }; + name = Release; + }; /* End XCBuildConfiguration section */ /* Begin XCConfigurationList section */ - 229E5917265869BF006E41AE /* Build configuration list for PBXProject "ios_package_test" */ = { + 229E5917265869BF006E41AE /* Build configuration list for PBXProject "apple_package_test" */ = { isa = XCConfigurationList; buildConfigurations = ( 229E5949265869C2006E41AE /* Debug */, @@ -468,6 +754,24 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; + 51C316DF2B0881490033C70B /* Build configuration list for PBXNativeTarget "macos_package_test" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 51C316E02B0881490033C70B /* Debug */, + 51C316E12B0881490033C70B /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 51C316E52B0881490033C70B /* Build configuration list for PBXNativeTarget "macos_package_testUITests" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 51C316E62B0881490033C70B /* Debug */, + 51C316E72B0881490033C70B /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; /* End XCConfigurationList section */ }; rootObject = 229E5914265869BF006E41AE /* Project object */; diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.xcworkspace/contents.xcworkspacedata similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test.xcodeproj/project.xcworkspace/contents.xcworkspacedata rename to onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.xcworkspace/contents.xcworkspacedata diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist rename to onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist diff --git a/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings new file mode 100644 index 0000000000000..0c67376ebacb4 --- /dev/null +++ b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings @@ -0,0 +1,5 @@ + + + + + diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test/AppDelegate.h 
b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/AppDelegate.h similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test/AppDelegate.h rename to onnxruntime/test/platform/apple/apple_package_test/ios_package_test/AppDelegate.h diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test/AppDelegate.m b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/AppDelegate.m similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test/AppDelegate.m rename to onnxruntime/test/platform/apple/apple_package_test/ios_package_test/AppDelegate.m diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test/Base.lproj/LaunchScreen.storyboard b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/Base.lproj/LaunchScreen.storyboard similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test/Base.lproj/LaunchScreen.storyboard rename to onnxruntime/test/platform/apple/apple_package_test/ios_package_test/Base.lproj/LaunchScreen.storyboard diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test/Base.lproj/Main.storyboard b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/Base.lproj/Main.storyboard similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test/Base.lproj/Main.storyboard rename to onnxruntime/test/platform/apple/apple_package_test/ios_package_test/Base.lproj/Main.storyboard diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test/Info.plist b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/Info.plist similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test/Info.plist rename to onnxruntime/test/platform/apple/apple_package_test/ios_package_test/Info.plist diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_test/main.m b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/main.m similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_test/main.m rename to onnxruntime/test/platform/apple/apple_package_test/ios_package_test/main.m diff --git a/onnxruntime/test/platform/ios/ios_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm rename to onnxruntime/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/AppDelegate.h b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/AppDelegate.h new file mode 100644 index 0000000000000..e7b3600a059cb --- /dev/null +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/AppDelegate.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+// +// AppDelegate.h +// macos_package_test +// + +#import + +@interface AppDelegate : NSObject + +@end diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/AppDelegate.m b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/AppDelegate.m new file mode 100644 index 0000000000000..36d16491c63b1 --- /dev/null +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/AppDelegate.m @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// AppDelegate.h +// macos_package_test +// + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + +- (void)applicationDidFinishLaunching:(NSNotification*)aNotification { + // Insert code here to initialize your application +} + +- (void)applicationWillTerminate:(NSNotification*)aNotification { + // Insert code here to tear down your application +} + +- (BOOL)applicationSupportsSecureRestorableState:(NSApplication*)app { + return YES; +} + +@end diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/Base.lproj/Main.storyboard b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000..1cddb62a02eb6 --- /dev/null +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/Base.lproj/Main.storyboard @@ -0,0 +1,719 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Default + + + + + + + Left to Right + + + + + + + Right to Left + + + + + + + + + + + Default + + + + + + + Left to Right + + + + + + + Right to Left + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/macos_package_test.entitlements b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/macos_package_test.entitlements new file mode 100644 index 0000000000000..18aff0ce43c20 --- /dev/null +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/macos_package_test.entitlements @@ -0,0 +1,10 @@ + + + + + com.apple.security.app-sandbox + + com.apple.security.files.user-selected.read-only + + + diff --git 
a/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/main.m b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/main.m new file mode 100644 index 0000000000000..ee939ac3752c1 --- /dev/null +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_test/main.m @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// AppDelegate.h +// macos_package_test +// + +#import + +int main(int argc, const char* argv[]) { + @autoreleasepool { + // Setup code that might create autoreleased objects goes here. + } + return NSApplicationMain(argc, argv); +} diff --git a/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm new file mode 100644 index 0000000000000..613c6e545939f --- /dev/null +++ b/onnxruntime/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// macos_package_test_cpp_api.mm +// macos_package_test_cpp_api +// +// This file hosts the tests of ORT C++ API +// + +#import +#include +#include + +#if __has_include() +#define COREML_EP_AVAILABLE 1 +#else +#define COREML_EP_AVAILABLE 0 +#endif + +#if COREML_EP_AVAILABLE +#include +#endif + +void testSigmoid(const char* modelPath, bool useCoreML) { + // This is an e2e test for ORT C++ API + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "testCppAPI"); + + // initialize session options if needed + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + +#if COREML_EP_AVAILABLE + if (useCoreML) { + const uint32_t flags = COREML_FLAG_USE_CPU_ONLY; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags)); + } +#else + (void)useCoreML; +#endif + + Ort::Session session(env, modelPath, session_options); + + size_t input_tensor_size = 3 * 4 * 5; + float input_tensor_values[input_tensor_size]; + float expected_output_values[input_tensor_size]; + const char* input_node_names[] = {"x"}; + const char* output_node_names[] = {"y"}; + const int64_t input_node_dims[] = {3, 4, 5}; + + for (size_t i = 0; i < input_tensor_size; i++) { + input_tensor_values[i] = (float)i - 30; + expected_output_values[i] = 1.0f / (1 + exp(-input_tensor_values[i])); + } + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::Value input_tensor = + Ort::Value::CreateTensor(memory_info, input_tensor_values, input_tensor_size, input_node_dims, 3); + XCTAssert(input_tensor.IsTensor()); + + auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names, + &input_tensor, 1, output_node_names, 1); + XCTAssertEqual(output_tensors.size(), 1); + XCTAssert(output_tensors.front().IsTensor()); + + // Get pointer to output tensor float values + float* output_values = output_tensors.front().GetTensorMutableData(); + for (size_t i = 0; i < input_tensor_size; i++) { + XCTAssertEqualWithAccuracy(expected_output_values[i], output_values[i], 1e-6); + } +} + +@interface macos_package_testUITests : XCTestCase + +@end + +@implementation macos_package_testUITests + +- (void)setUp { + // Put setup code here. This method is called before the invocation of each test method in the class. 
+ + // In UI tests it is usually best to stop immediately when a failure occurs. + self.continueAfterFailure = NO; + + // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this. +} + +- (void)tearDown { + // Put teardown code here. This method is called after the invocation of each test method in the class. +} + +- (NSString*)getFilePath { + NSBundle* bundle = [NSBundle bundleForClass:[self class]]; + NSString* ns_model_path = [bundle pathForResource:@"sigmoid" ofType:@"ort"]; + XCTAssertNotNil(ns_model_path); + return ns_model_path; +} + +- (void)testCppAPI_Basic { + testSigmoid([self getFilePath].UTF8String, false /* useCoreML */); +} + +#if COREML_EP_AVAILABLE +- (void)testCppAPI_Basic_CoreML { + testSigmoid([self getFilePath].UTF8String, true /* useCoreML */); +} +#endif + +@end diff --git a/onnxruntime/test/platform/ios/ios_package_test/models/sigmoid.ort b/onnxruntime/test/platform/apple/apple_package_test/models/sigmoid.ort similarity index 100% rename from onnxruntime/test/platform/ios/ios_package_test/models/sigmoid.ort rename to onnxruntime/test/platform/apple/apple_package_test/models/sigmoid.ort diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index b94c17c3b0e24..6eb72255bdf9a 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -421,7 +421,7 @@ TEST(SoftmaxOperator, GH15949_regression_test) { {0.00032932f, 0.01798029f, 0.9816904f}); // disable TRT as it does not support axis=0 as used by the model - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kCoreMLExecutionProvider}); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } } // namespace test diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 7712a0a5bf724..70a43d660decb 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -94,7 +94,7 @@ constexpr T ValueFromIdx(size_t idx) { } template -void SplitTestAxis0EqualSplit(bool use_opset_13 = false) { +void SplitTestAxis0EqualSplit() { SCOPED_TRACE(onnxruntime::MakeString("data type: ", utils::ToTensorProtoElementType())); constexpr int64_t axis = 0; @@ -117,11 +117,20 @@ void SplitTestAxis0EqualSplit(bool use_opset_13 = false) { {V(5), V(6), V(7), V(8)}}); + // BFloat16 added in opset 13 + if constexpr (!std::is_same_v) { + RunTest(axis, {}, input, outputs, + // TensorRT parser: Assertion failed: axis != BATCH_DIM + {kTensorrtExecutionProvider}, // is_tensorrt_supported + false, // expect_failure + false /*split_as_input*/); + } + RunTest(axis, {}, input, outputs, // TensorRT parser: Assertion failed: axis != BATCH_DIM {kTensorrtExecutionProvider}, // is_tensorrt_supported false, // expect_failure - use_opset_13); // split_as_input + true /*split_as_input*/); } } // namespace @@ -130,7 +139,7 @@ TEST(SplitOperatorTest, Axis0EqualSplit) { SplitTestAxis0EqualSplit(); SplitTestAxis0EqualSplit(); SplitTestAxis0EqualSplit(); - SplitTestAxis0EqualSplit(true); // BFloat16 added in opset 13 + SplitTestAxis0EqualSplit(); SplitTestAxis0EqualSplit(); SplitTestAxis0EqualSplit(); SplitTestAxis0EqualSplit(); @@ -162,8 +171,11 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) { {3.f, 4.f, 5.f, 6.f, 7.f, 
8.f}}); + // TensorRT parser: Assertion failed: axis != BATCH_DIM RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); + // CoreML EP, etc. requires split to be an input. Same applies to below sets of tests. + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); } TEST(SplitOperatorTest, Axis0UnequalSplitString) { @@ -186,6 +198,7 @@ TEST(SplitOperatorTest, Axis0UnequalSplitString) { "e", "f", "g", "h"}}); // TensorRT parser: Assertion failed: axis != BATCH_DIM + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -205,7 +218,7 @@ TEST(SplitOperatorTest, Axis1EqualSplitFloat) { outputs.push_back({{2, 2}, {3.f, 4.f, 7.f, 8.f}}); - + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -226,6 +239,7 @@ TEST(SplitOperatorTest, Axis1EqualSplitString) { {"c", "d", "g", "h"}}); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -248,6 +262,7 @@ TEST(SplitOperatorTest, Axis1UnequalSplitFloat) { {4.f, 8.f}}); + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -270,6 +285,7 @@ TEST(SplitOperatorTest, Axis1UnequalSplitString) { {"d", "h"}}); + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -312,6 +328,7 @@ TEST(SplitOperatorTest, Axis2EqualSplit) { 17.f, 18.f, 23.f, 24.f}}); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -344,6 +361,9 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) { 16.f, 17.f, 18.f, 22.f, 23.f, 24.f}}); + // Note: temporarily marked qnn ep as excluded when running tests with split_as_input=true. + // TODO: Need to resolve to see if it's not supported or test case failure. 
+ RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true); RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -353,7 +373,7 @@ TEST(SplitOperatorTest, ZeroSizeInput) { ShapeAndFloatData input = CreateInput({0, 2}); - RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); } // test a split of a dimension that has leading and trailing dimensions @@ -377,6 +397,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -403,6 +424,7 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -423,6 +445,7 @@ TEST(SplitOperatorTest, NegativeAxis) { {3.f, 4.f, 7.f, 8.f}}); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -439,6 +462,7 @@ TEST(SplitOperatorTest, InvalidAxis) { outputs.push_back({{1}, {0.f}}); + RunTest(axis, {}, input, outputs, {}, true, true, -1, true, "Invalid value of attribute 'axis'"); RunTest(axis, {}, input, outputs, {}, true, false, -1, true, "Invalid value of attribute 'axis'"); } @@ -459,6 +483,8 @@ TEST(SplitOperatorTest, SplitAttributeSumTooSmall) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{2, 2}, {3.f, 4.f, 5.f, 6.f}}); + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, true, -1, true, + "[ShapeInferenceError] Mismatch between the sum of 'split'"); RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, false, -1, true, "[ShapeInferenceError] Mismatch between the sum of 'split'"); // TensorRT parser: Assertion failed: axis != BATCH_DIM } @@ -478,6 +504,8 @@ TEST(SplitOperatorTest, InvalidValueInSplitAttribute) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{3, 2}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}); + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, true, -1, true, + "[ShapeInferenceError] Mismatch between number of splits"); RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, false, -1, true, "[ShapeInferenceError] Mismatch between number of splits"); // TensorRT parser: Assertion failed: axis != BATCH_DIM } @@ -654,7 +682,8 @@ TEST(SplitOperatorTest, MissingOptionalInputAdded) { {3.f, 4.f, 7.f, 8.f}}); - RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, -1, false, {}, false); + // CoreML EP does not support the case when split_is_input==true but missing providing the split as initializer. 
+ RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kCoreMLExecutionProvider}, false, true, -1, false, {}, false); } TEST(SplitOperatorTest, Split18_NumOutputs_EvenSplit) { @@ -677,6 +706,9 @@ TEST(SplitOperatorTest, Split18_NumOutputs_EvenSplit) { 7.f, 8.f}}); int64_t num_outputs = 2; +#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, true); +#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, false); } @@ -703,6 +735,9 @@ TEST(SplitOperatorTest, Split18_NumOutputs_UnevenSplit) { outputs.push_back({{1, 2}, {9.f, 10.f}}); int64_t num_outputs = 3; +#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, true); +#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false); } @@ -728,6 +763,10 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) { }; RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false, "Attribute `num_outputs` value cannot be lower than 1"); +#ifdef USE_COREML + RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, true, + "Attribute `num_outputs` value cannot be lower than 1"); +#endif outputs.clear(); outputs.push_back({{1, 2}, @@ -738,6 +777,10 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) { num_outputs = 3; RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false, "Invalid num_outputs value of 3. Size of dimension being split is 2"); +#ifdef USE_COREML + RunTest(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, true, + "Invalid num_outputs value of 3. Size of dimension being split is 2"); +#endif } TEST(SplitOperatorTest, Split18_NumOutputsEvenSplitAxis1) { @@ -755,6 +798,9 @@ TEST(SplitOperatorTest, Split18_NumOutputsEvenSplitAxis1) { int64_t num_outputs = 3; RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, false); +#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs); +#endif } TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) { @@ -772,6 +818,9 @@ TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) { outputs.push_back({{2, 1}, {3.f, 6.f}}); int64_t num_outputs = 2; +#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs); +#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false); } diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc new file mode 100644 index 0000000000000..aba2b0b2cb4a4 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc @@ -0,0 +1,507 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "core/framework/float16.h" +#include "core/mickey/blk_q4/prepack_sm80.h" +#include "core/mlas/inc/mlas_q4.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +void prepack_weights_ref( + int rows, + int columns, + const MatrixRef& tensor_weight, + const MatrixRef& tensor_weight_prepacked) { + EXPECT_TRUE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns); + EXPECT_TRUE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2); + + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } +} + +template < + typename ScaleElementT, + typename Layout, + typename QuantBlocking> +void prepack_quant_scales_ref( + int rows, + int columns, + const MatrixRef& tensor_scale, + const MatrixRef& tensor_scale_prepacked) { + EXPECT_TRUE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn)); + EXPECT_TRUE(tensor_scale_prepacked.shape() == tensor_scale.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. 
With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + FAIL() << "Scale prepack only supported for 16b gemm with (1,n) quantization blocking"; + } +} + +template +void prepack_quant_offsets_ref( + size_t rows, + size_t columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + // EXPECT_TRUE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn)); + EXPECT_TRUE(tensor_offset_prepacked.shape() == tensor_offset.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + FAIL() << "Offsets prepack only supported for 16b gemm with (1,n) quantization blocking"; + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. 
To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); + tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); + tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); + tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); + } + } + } + } +} + +template +void testPrepack(int rows, int columns, bool has_offset = true) { + using ElementT = MLFloat16; + constexpr int block_size = 32; + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + ColumnMajorQuantBlocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + using LayoutQmeta = typename Base::LayoutQmeta; + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution<> dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + + // + // For testing quantization and dequantization, it is not straight + // forward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantied-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values are key to get this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. 
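The round-trip property this test strategy relies on can be seen in isolation by applying the same dequantization rule, scale * (q - offset), to a single packed byte and quantizing it back. The sketch below is a minimal standalone illustration under those assumptions; the names and constants are illustrative only and are not part of the ORT sources or of this patch.

#include <cassert>
#include <cmath>
#include <cstdint>

// Two 4-bit weights packed into one byte, dequantized with a shared positive
// scale and a 4-bit zero point, then re-quantized. With exactly representable
// values the original nibbles are recovered, which is the property the test
// checks at block granularity for the full weight tensor.
int main() {
  const uint8_t packed = 0xC3;  // low nibble = 3, high nibble = 12
  const float scale = 0.25f;    // scales must be positive
  const int offset = 8;         // 4-bit zero point

  for (int n = 0; n < 2; ++n) {
    const int q = (n == 0) ? (packed & 0x0f) : (packed >> 4);
    const float dequant = scale * static_cast<float>(q - offset);
    const int requant = static_cast<int>(std::lround(dequant / scale)) + offset;
    assert(requant == q);  // round trip is exact for these values
  }
  return 0;
}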
+ // + + std::vector q_weights(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + std::vector q_scales(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + std::vector q_zp(meta_shape.product()); + for (size_t i = 0; i < q_zp.size(); i++) { + q_zp[i] = dis(gen) % 16; + } + MatrixRef tensor_offset( + q_zp, meta_shape); + +#if 0 // debug + // Fill tensor_q_weight with the patterned data, easier to debug with print + int loop_val = 0; + int offset = 3; + for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) { + for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) { + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); + auto val = (loop_val + offset) % 256; + tensor_q_weight.at(weight_cord) = ElementW(val); + loop_val++; + if (loop_val == 256) { + loop_val = 0; + offset += 11; + } + } + } + } + } + for (int col = 0; col < tensor_scale.extent().column(); ++col){ + int c = col * QuantBlocking::kColumn; + for (int row = 0; row < tensor_scale.extent().row(); ++row){ + int r = row * QuantBlocking::kRow; + auto weight_cord = cutlass::make_Coord(r/2, c); + int w = 0; + if (r % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + tensor_scale.at({row, col}) = w; + tensor_offset.at({row, col}) = ElementQOffset(w); + } + } + + int fill_val = -512; + int factor = 1; + for (int col = 0; col < tensor_scale.extent().column(); ++col){ + for (int row = 0; row < tensor_scale.extent().row(); ++row){ + tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); + fill_val++; + if (fill_val == 512) { + fill_val = -512; + factor += 1; + } + } + } + +#endif // debug + + std::vector dequants(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B for reference + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + const uint8_t offset = has_offset ? 
tensor_offset.at(scale_cord) : 8; + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape( + block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols); + // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes + EXPECT_EQ(q_rows, q_weight_shape[0]); + EXPECT_EQ(q_cols, q_weight_shape[1]); + + // + // Quantization tool outputs: + // + std::vector o_elements(q_rows * q_cols); + MatrixRef tensor_o_elements(o_elements, q_weight_shape); + + std::vector o_scales(meta_shape.product()); + MatrixRef tensor_o_scales(o_scales, meta_shape); + + std::vector o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true); + MatrixRef tensor_o_zp( + o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); + + MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, + tensor_dequant.data().data(), block_size, + ColumnMajorQuantBlocking, rows, columns, columns, nullptr); + for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { + for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { + EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) + << "quantized value mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + if (has_offset) { + uint8_t pair01 = tensor_o_zp.at(row / 2, col); + EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf) + << "quantized offset mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4) + << "quantized offset mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) + << "quantized scale mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) + << "quantized scale mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } + + // + // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight, + // quantization scale tensor_scale and quantization offset tensor_offset. The above + // testing just make sure our test setup is consistent with quantization tool output. 
+ // + // Next we test the prepack code + // + + std::vector packed_w_ref(q_weight_shape.product()); + MatrixRef tensor_packed_w_ref( + packed_w_ref, make_Position(rows, columns / 2)); + prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); + + std::vector packed_w(q_weight_shape.product()); + MatrixRef tensor_packed_w( + packed_w, make_Position(rows, columns / 2)); + Base::prepack_weights(rows, columns, o_elements, packed_w); + + for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) + << "prepacked weights mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + std::vector packed_scales_ref(meta_shape.product()); + MatrixRef tensor_packed_s_ref = + Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) + : tensor_scale; + if (Base::ShouldRearrangeMeta) { + prepack_quant_scales_ref( + rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); + } + + std::vector packed_scales(meta_shape.product()); + MatrixRef tensor_packed_s( + packed_scales, meta_shape); + Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); + + for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) + << "prepacked scales mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + if (has_offset) { + std::vector packed_zp_ref(meta_shape.product()); + MatrixRef tensor_packed_zp_ref = + Base::ShouldRearrangeMeta ? make_MatrixRef(packed_zp_ref, meta_shape) + : tensor_offset; + if (Base::ShouldRearrangeMeta) { + prepack_quant_offsets_ref( + rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); + } + + std::vector packed_zp(meta_shape.product()); + MatrixRef tensor_packed_zp( + packed_zp, meta_shape); + Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); + + for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) + << "prepacked offsets mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } +} + +// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 +TEST(BlkQ4_GEMM, PrepackSm80Test) { + testPrepack(32, 32); + testPrepack(32, 32, false); + testPrepack(32, 32); + testPrepack(32, 32, false); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128, false); + testPrepack(128, 32, false); + testPrepack(256, 256, false); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128, false); + testPrepack(128, 32, false); + testPrepack(256, 256, false); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index e024eafcd6572..3435bd71aa4b3 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -786,7 +786,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { // Check the Onnx skeleton file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // Check the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNN_8283143575221199085_1.bin")); + EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin")); // 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), @@ -806,6 +806,62 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { context_binary_file); } +// Run QDQ model on HTP 2 times +// 1st run will generate the Onnx skeleton file + Qnn context cache binary file +// Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH +TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["qnn_context_cache_enable"] = "1"; + const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + provider_options["qnn_context_cache_path"] = context_binary_file; + provider_options["qnn_context_embed_mode"] = "0"; + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+ // 1st run will generate the Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All); + + // Check the Onnx skeleton file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + // Check the Qnn context cache binary file is generated + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + EXPECT_TRUE(std::filesystem::exists(context_bin)); + // Delete the Qnn context cache binary file + EXPECT_TRUE(std::filesystem::remove(context_bin)); + + // loads and run from Onnx skeleton file + Qnn context cache binary file + onnx::ModelProto model_proto; + onnxruntime::Model qnn_ctx_model; + // Load the QNN context cache model from path specified + ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto)); + std::string qnn_ctx_model_data; + model_proto.SerializeToString(&qnn_ctx_model_data); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + std::string provider_type = kCpuExecutionProvider; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + // Run QDQ model on HTP with 2 inputs // 1st run will generate the Qnn context cache onnx file // 2nd run will load and run from QDQ model + Qnn context cache model @@ -1219,6 +1275,28 @@ TEST_F(QnnHTPBackendTests, VariadicOp_Concat_2Inputs_2ndAxis) { 13, ExpectedEPNodeAssignment::All); } + +TEST_F(QnnHTPBackendTests, LpNormalization_u8_rank4) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunQDQOpTest("LpNormalization", + {TestInputDef({1, 2, 2, 2}, false, input_data)}, + {utils::MakeAttribute("axis", static_cast(-1)), // Last axis + utils::MakeAttribute("p", static_cast(2))}, // Order 2 to map to QNN's L2Norm operator + 13, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LpNormalization_u16_rank4) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunQDQOpTest("LpNormalization", + {TestInputDef({1, 2, 2, 2}, false, input_data)}, + {utils::MakeAttribute("axis", static_cast(-1)), // Last axis + utils::MakeAttribute("p", static_cast(2))}, // Order 2 to map to QNN's L2Norm operator + 13, + ExpectedEPNodeAssignment::All, + kOnnxDomain, + true); +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc b/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc index 225649ef391b1..65db81e7f4013 100644 --- a/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc +++ b/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc @@ -9,8 +9,9 @@ #include "core/framework/utils.h" #include "core/graph/graph.h" #include "core/providers/xnnpack/xnnpack_execution_provider.h" -#include "core/session/onnxruntime_cxx_api.h" #include "core/session/inference_session.h" +#include "core/session/onnxruntime_cxx_api.h" +#include 
"core/session/onnxruntime_session_options_config_keys.h" #include "test/common/tensor_op_test_utils.h" #include "test/framework/test_utils.h" @@ -214,8 +215,13 @@ static void RunModelTestWithPath(const ORTCHAR_T* ort_model_path, const char* gr NameMLValMap feeds; feeds.insert(std::make_pair("input", ml_value_x)); + // XNNPACK supports int8 data + std::function so_updater = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsQDQIsInt8Allowed, "1")); + }; + auto ep = DefaultXnnpackExecutionProvider(); - RunAndVerifyOutputsWithEP(ort_model_path, graph_name, std::move(ep), feeds, params); + RunAndVerifyOutputsWithEP(ort_model_path, graph_name, std::move(ep), feeds, params, so_updater); } TEST(XnnpackEP, DISABLED_TestQDQConvU8U8) { // [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for QuantizeLinear(19) node with name 'node_token_12' @@ -254,8 +260,7 @@ TEST(XnnpackEP, DISABLED_TestQDQConvS8S8) { // [ONNXRuntimeError] : 9 : NOT_IM TEST(XnnpackEP, TestQDQConvS8S8_per_channel) { std::function graph_verify = [](const Graph& graph) -> void { - ASSERT_EQ(graph.NumberOfNodes(), 5) << "Transpose*2 + dq +q +qlinearconv " - "leaving 5 nodes."; + ASSERT_EQ(graph.NumberOfNodes(), 5) << "-> Q -> Transpose -> QLinearConv -> Transpose -> DQ."; }; const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv_qdq_s8s8_perchannel.onnx"; RunModelTestWithPath(ort_model_path, "xnnpack_qdq_test_graph_conv_s8s8_perchannel", graph_verify, 0.2f); diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py index 7dffad8f84c83..482a334b12b85 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -14,6 +14,7 @@ from numpy.testing import assert_allclose from onnx import TensorProto from onnx.checker import check_model +from onnx.defs import onnx_opset_version from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info from onnx.numpy_helper import from_array @@ -91,7 +92,10 @@ def get_model_gemm( ] nodes = [n for n in nodes if n is not None] graph = make_graph(nodes, "gemm", inputs, [d], inits) - onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9) + opset_imports = [make_opsetid("", onnx_opset_version() - 1)] + if domain == "com.microsoft": + opset_imports.append(make_opsetid("com.microsoft", 1)) + onnx_model = make_model(graph, opset_imports=opset_imports, ir_version=9) if domain != "com.microsoft": check_model(onnx_model) return onnx_model @@ -268,7 +272,8 @@ def test_combinations(self, shapeA, shapeB, transA, transB): make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), ], [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])], - ) + ), + opset_imports=[make_opsetid("", 19), make_opsetid("com.microsoft", 1)], ) sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index f26b6297cdbda..eede1be05f85f 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -393,6 +393,9 @@ def check_qtype_by_node_type(testcase, model_to_check, check_list): model = onnx.load(model_to_check) elif isinstance(model_to_check, onnx.ModelProto): model = model_to_check + # NOTE: ONNX shape inference 
does not work on MS domain nodes. + # Therefore, this function cannot currently be used for graphs that contain ops such as + # com.microsoft.QuantizeLinear, which support 16-bit quantization. model = onnx.shape_inference.infer_shapes(model) value_infos = {vi.name: vi for vi in model.graph.value_info} value_infos.update({ot.name: ot for ot in model.graph.output}) diff --git a/onnxruntime/test/python/quantization/test_op_softmax.py b/onnxruntime/test/python/quantization/test_op_softmax.py index 8e6e4d4100348..3416198450137 100644 --- a/onnxruntime/test/python/quantization/test_op_softmax.py +++ b/onnxruntime/test/python/quantization/test_op_softmax.py @@ -43,6 +43,7 @@ def construct_model_conv_softmax( softmax_input_shape, softmax_attributes, output_shape, + add_ms_domain_opset=False, ): # (input) # \ @@ -74,11 +75,16 @@ def construct_model_conv_softmax( [identity_out, output_tensor], initializer=initializers, ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + + opset_imports = [helper.make_opsetid("", 13)] + if add_ms_domain_opset: + opset_imports.append(helper.make_opsetid("com.microsoft", 1)) + + model = helper.make_model(graph, opset_imports=opset_imports) model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def quantize_softmax_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 + def quantize_softmax_test_qop(self, activation_type, weight_type, extra_options={}): # noqa: B006 np.random.seed(1) model_fp32_path = "softmax_fp32.onnx" self.construct_model_conv_softmax( @@ -91,11 +97,10 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}): ) data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) - activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" - weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + activation_proto_qtype = activation_type.tensor_type + activation_type_str = str(activation_type) + weight_type_str = str(weight_type) model_q8_path = f"softmax_{activation_type_str}{weight_type_str}.onnx" - model_q8_qdq_path = f"softmax_qdq_{activation_type_str}{weight_type_str}.onnx" # Verify QOperator mode data_reader.rewind() @@ -138,11 +143,30 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}): data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next()) + def quantize_softmax_test_qdq(self, activation_type, weight_type, extra_options={}): # noqa: B006 + np.random.seed(1) + model_fp32_path = "softmax_fp32.onnx" + self.construct_model_conv_softmax( + model_fp32_path, + [1, 2, 26, 42], + [3, 2, 3, 3], + [1, 3, 24, 40], + {"axis": -2}, + [1, 3, 24, 40], + add_ms_domain_opset=extra_options.get("UseQDQContribOps", False), + ) + data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) + + activation_proto_qtype = activation_type.tensor_type + activation_type_str = str(activation_type) + weight_type_str = str(weight_type) + model_qdq_path = f"softmax_qdq_{activation_type_str}{weight_type_str}.onnx" + # Verify QDQ mode data_reader.rewind() quantize_static( model_fp32_path, - model_q8_qdq_path, + model_qdq_path, data_reader, quant_format=QuantFormat.QDQ, activation_type=activation_type, @@ -150,7 +174,7 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}): extra_options=extra_options, ) - result_model = 
onnx.load(Path(model_q8_qdq_path)) + result_model = onnx.load(Path(model_qdq_path)) qnode_cnt = 0 dqnode_cnt = 0 softmax_cnt = 0 @@ -166,9 +190,15 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}): self.assertEqual(3, qnode_cnt, f"Expected 3 QuantizeLinear nodes, found {qnode_cnt}") self.assertEqual(4, dqnode_cnt, f"Expected 4 DequantizeLinear nodes, found {dqnode_cnt}") self.assertEqual(1, softmax_cnt, f"Expected 1 Softmax node, found {softmax_cnt}") - if extra_options.get("ActivationSymmetric", False): - for tensor in result_model.graph.initializer: - if tensor.name in qnode_zeropoints: + for tensor in result_model.graph.initializer: + if tensor.name in qnode_zeropoints: + self.assertEqual( + tensor.data_type, + activation_proto_qtype, + f"QuantizeLinear zero-point must be of proto type {activation_proto_qtype}, " + f"but found {tensor.data_type} instead.", + ) + if extra_options.get("ActivationSymmetric", False): np_value = numpy_helper.to_array(tensor) self.assertEqual( 0, @@ -176,30 +206,52 @@ def quantize_softmax_test(self, activation_type, weight_type, extra_options={}): f"QuantizeLinear node zero point value must be 0, found {np_value} instead!", ) - qnode_io_qtypes = { - "QuantizeLinear": [ - ["i", 2, activation_proto_qtype], - ["o", 0, activation_proto_qtype], - ] - } - check_qtype_by_node_type(self, model_q8_qdq_path, qnode_io_qtypes) data_reader.rewind() - check_model_correctness(self, model_fp32_path, model_q8_qdq_path, data_reader.get_next()) + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) def test_quantize_softmax(self): - self.quantize_softmax_test(QuantType.QUInt8, QuantType.QUInt8) + self.quantize_softmax_test_qop(QuantType.QUInt8, QuantType.QUInt8) + self.quantize_softmax_test_qdq(QuantType.QUInt8, QuantType.QUInt8) def test_quantize_softmax_s8s8(self): - self.quantize_softmax_test( + self.quantize_softmax_test_qop( + QuantType.QInt8, + QuantType.QInt8, + ) + self.quantize_softmax_test_qdq( + QuantType.QInt8, + QuantType.QInt8, + ) + self.quantize_softmax_test_qop( QuantType.QInt8, QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, ) - self.quantize_softmax_test( + self.quantize_softmax_test_qdq( QuantType.QInt8, QuantType.QInt8, extra_options={"ActivationSymmetric": True}, ) + def test_quantize_softmax_qdq_u16u16(self): + self.quantize_softmax_test_qdq( + QuantType.QUInt16, + QuantType.QUInt16, + extra_options={"UseQDQContribOps": True}, + ) + + def test_quantize_softmax_qdq_s16s16(self): + self.quantize_softmax_test_qdq( + QuantType.QInt16, + QuantType.QInt16, + extra_options={"UseQDQContribOps": True}, + ) + self.quantize_softmax_test_qdq( + QuantType.QInt16, + QuantType.QInt16, + extra_options={"UseQDQContribOps": True, "ActivationSymmetric": True}, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/layout_transform_nonconst_broadcast_input.onnx b/onnxruntime/test/testdata/layout_transform_nonconst_broadcast_input.onnx new file mode 100644 index 0000000000000..8682be9992c62 Binary files /dev/null and b/onnxruntime/test/testdata/layout_transform_nonconst_broadcast_input.onnx differ diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx index 9d82d68a40098..797584f10ab24 100644 Binary files a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx differ 
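As a usage illustration of the 16-bit QDQ path exercised by the new softmax tests above: the whole flow is driven through quantize_static's extra_options. The sketch below is illustrative only and not part of this change set; the model file name and the toy calibration reader are hypothetical stand-ins for the test fixtures.
# Minimal sketch of the 16-bit QDQ flow covered by test_quantize_softmax_qdq_u16u16.
# Assumes a hypothetical "softmax_fp32.onnx" whose single input "input" has shape [1, 2, 26, 42].
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomDataReader(CalibrationDataReader):
    """Feeds a single random batch for calibration; a stand-in for the tests' data reader."""
    def __init__(self, num_batches=1):
        self._feeds = iter(
            [{"input": np.random.rand(1, 2, 26, 42).astype(np.float32)} for _ in range(num_batches)]
        )

    def get_next(self):
        return next(self._feeds, None)

quantize_static(
    "softmax_fp32.onnx",        # hypothetical float32 input model
    "softmax_qdq_u16u16.onnx",  # output QDQ model
    RandomDataReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt16,
    # 16-bit Q/DQ requires the com.microsoft contrib QuantizeLinear/DequantizeLinear ops.
    extra_options={"UseQDQContribOps": True},
)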
diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py index d4e3f7e8cbab6..d710c796fb0ad 100644 --- a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py +++ b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py @@ -53,8 +53,64 @@ def create_model(broadcast_weights: bool): return model +def create_model_with_Where(): # noqa 'Where' is the operator name + """ + Create a model to validate the logic to cancel out the Transpose -> Squeeze -> DQ between an updated shared + initializer and other usage. We need to use Where as we require more than 2 inputs. + The `condition` input will have a Transpose pushed through it, so it carries a negative cost. + The `X` input will have a positive cost which cancels out the negative value. + The `Y` input will be a shared initializer that is broadcast. If we don't find the Transpose to make its cost + negative, we will not push the Transpose through. + + If we only have 2 inputs, the broadcast initializer will always cost less due to its smaller rank, meaning we don't + actually need to look for the Squeeze in that case. + """ + cond_0_shape = [3, 2] # transpose to 2, 3 + cond_1_shape = [2, 3] + x_0_shape = [3] # broadcast so Transpose goes through Where0 + x_1_shape = [3] # also broadcast + y_shape = [3] # should be transposed and broadcast to [3, 1] if we push the transpose through the Where + y_values = np.random.randn(3) + + graph = helper.make_graph( + name="graph", + inputs=[ + helper.make_tensor_value_info("cond_in_0", TensorProto.BOOL, cond_0_shape), + helper.make_tensor_value_info("cond_in_1", TensorProto.BOOL, cond_1_shape), + helper.make_tensor_value_info("x_in_0", TensorProto.FLOAT, x_0_shape), + helper.make_tensor_value_info("x_in_1", TensorProto.FLOAT, x_1_shape), + ], + initializer=[ + helper.make_tensor("y_quant", TensorProto.UINT8, y_shape, y_values.astype(np.uint8)), + helper.make_tensor("dq_scale0", TensorProto.FLOAT, [], [1.5]), + helper.make_tensor("dq_scale1", TensorProto.FLOAT, [], [0.5]), + ], + nodes=[ + # Transpose the cond input + helper.make_node("Transpose", ["cond_in_0"], ["cond_in_T"], perm=[1, 0]), + helper.make_node("DequantizeLinear", ["y_quant", "dq_scale0"], ["DQ0"], "DQ0"), + # first usage of shared initializer. simple so we know the Transpose can push through it + helper.make_node("Where", ["cond_in_T", "x_in_0", "DQ0"], ["Where0"], "Where0"), + helper.make_node("DequantizeLinear", ["y_quant", "dq_scale1"], ["DQ1"], "DQ1"), + helper.make_node("Add", ["x_in_1", "Where0"], ["Add0"], "Add0"), + # second usage of shared initializer.
requires looking past the Squeeze to push the transpose through + helper.make_node("Where", ["cond_in_1", "Add0", "DQ1"], ["Where1"], "Where1"), + helper.make_node("Transpose", ["Where1"], ["output0"], perm=[1, 0]), + ], + outputs=[ + helper.make_tensor_value_info("output0", TensorProto.FLOAT, [3, 2]), + ], + ) + + model = helper.make_model(graph) + onnx.checker.check_model(model, full_check=True) + return model + + if __name__ == "__main__": model = create_model(broadcast_weights=False) onnx.save(model, "transpose_optimizer_shared_initializers.onnx") model = create_model(broadcast_weights=True) onnx.save(model, "transpose_optimizer_shared_initializers_broadcast.onnx") + model = create_model_with_Where() + onnx.save(model, "transpose_optimizer_shared_initializers_broadcast2.onnx") diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast2.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast2.onnx new file mode 100644 index 0000000000000..ad05fb70cb26e Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast2.onnx differ diff --git a/onnxruntime/test/util/include/test_utils.h b/onnxruntime/test/util/include/test_utils.h index eb072a134b924..48a71b8acb261 100644 --- a/onnxruntime/test/util/include/test_utils.h +++ b/onnxruntime/test/util/include/test_utils.h @@ -20,6 +20,7 @@ namespace onnxruntime { class Graph; +struct SessionOptions; namespace test { @@ -62,11 +63,13 @@ using ModelPathOrBytes = std::variant, // Run the model using the CPU EP to get expected output, comparing to the output when the 'execution_provider' // is enabled. +// session_options_updater can be used to update the SessionOptions the inference session is created with. void RunAndVerifyOutputsWithEP(ModelPathOrBytes model_path_or_bytes, std::string_view log_id, std::unique_ptr execution_provider, const NameMLValMap& feeds, - const EPVerificationParams& params = EPVerificationParams()); + const EPVerificationParams& params = EPVerificationParams(), + const std::function& session_options_updater = {}); // Tests model loading only. // This can be used to test EPs in builds where only loading (and not running) of a model is supported. diff --git a/onnxruntime/test/util/test_utils.cc b/onnxruntime/test/util/test_utils.cc index 43845a5052e36..5f1fdae72f031 100644 --- a/onnxruntime/test/util/test_utils.cc +++ b/onnxruntime/test/util/test_utils.cc @@ -132,11 +132,16 @@ static gsl::span GetModelBytes(ModelPathOrBytes model_path_or_b void RunAndVerifyOutputsWithEP(ModelPathOrBytes model_path_or_bytes, std::string_view log_id, std::unique_ptr execution_provider, const NameMLValMap& feeds, - const EPVerificationParams& params) { + const EPVerificationParams& params, + const std::function& session_options_updater) { std::vector model_data_buffer{}; const auto model_data = GetModelBytes(model_path_or_bytes, model_data_buffer); SessionOptions so; + if (session_options_updater) { + session_options_updater(so); + } + so.session_logid = log_id; RunOptions run_options; run_options.run_tag = so.session_logid; diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc index 3b701fa8bf577..0b38a79cc21c9 100644 --- a/orttraining/orttraining/core/agent/training_agent.cc +++ b/orttraining/orttraining/core/agent/training_agent.cc @@ -1,11 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include +#include +#include +#include + #include "orttraining/core/agent/training_agent.h" #include "core/framework/utils.h" #include "core/framework/feeds_fetches_manager.h" #include "core/framework/partial_graph_execution_state.h" #include "core/framework/stream_execution_context.h" +#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" namespace onnxruntime { namespace training { @@ -25,7 +31,8 @@ TrainingAgent::TrainingAgent(InferenceSession& session, std::vector bw_feed_names; size_t break_point = 0; - auto& training_node_execution_order = session_state.GetGraphViewer().GetNodesInTopologicalOrder(session.GetSessionOptions().execution_order); + auto& training_node_execution_order = session_state.GetGraphViewer().GetNodesInTopologicalOrder( + session.GetSessionOptions().execution_order); for (auto node_index : training_node_execution_order) { if (session_state.GetKernel(node_index)->KernelDef().OpName() == "YieldOp") { auto& node = *(session_state.GetGraphViewer().GetGraph().GetNode(node_index)); @@ -89,7 +96,8 @@ void TrainingAgent::CreateAndInitializeFeedsFetchesManager(const SessionState& s const std::vector& feed_names, const std::vector& fetches_names, const std::vector& outputs_device_info, - std::unique_ptr& feeds_fetches_manager) { + std::unique_ptr& + feeds_fetches_manager) { ORT_THROW_IF_ERROR(FeedsFetchesManager::Create(feed_names, fetches_names, session_state.GetOrtValueNameIdxMap(), feeds_fetches_manager)); auto& fetch_info = feeds_fetches_manager->GetMutableFetchesDeviceCopyInfo(); @@ -100,5 +108,23 @@ void TrainingAgent::CreateAndInitializeFeedsFetchesManager(const SessionState& s ORT_ENFORCE(utils::InitializeFeedFetchCopyInfo(session_state, *feeds_fetches_manager) == Status::OK()); } +std::string TrainingAgent::GetSerializedORTModuleMemoryStat(std::string_view memory_optimization_config, + std::string_view recompute_probe_level, + std::map>& + cluster_id_combinations_to_saved_symbolic_byte_map) + const { + auto& session_state = inference_session_.GetSessionState(); + const OrtValueNameIdxMap& ortvalue_name_to_idx_map = session_state.GetOrtValueNameIdxMap(); + const SequentialExecutionPlan& p_seq_exec_plan = *session_state.GetExecutionPlan(); + return optimizer::memory_optimizer::GetSerializedORTModuleMemoryStat( + session_state.GetGraphViewer(), + memory_optimization_config, + recompute_probe_level, + *inference_session_.GetLogger(), + cluster_id_combinations_to_saved_symbolic_byte_map, + &ortvalue_name_to_idx_map, + &p_seq_exec_plan); +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/agent/training_agent.h b/orttraining/orttraining/core/agent/training_agent.h index b12f5e6d75ef1..37e5272f66e32 100644 --- a/orttraining/orttraining/core/agent/training_agent.h +++ b/orttraining/orttraining/core/agent/training_agent.h @@ -5,11 +5,15 @@ #include #include +#include +#include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/framework/framework_common.h" #include "core/session/inference_session.h" +#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" namespace onnxruntime { struct PartialGraphExecutionState; @@ -45,6 +49,11 @@ class TrainingAgent { const std::vector& outputs_device_info, std::unique_ptr& feeds_fetches_manager); + std::string GetSerializedORTModuleMemoryStat(std::string_view memory_optimization_config, + std::string_view recompute_probe_level, + std::map>& + cluster_id_combinations_to_saved_symbolic_byte_map) const; + private: 
// TrainingAgent runs on a InferenceSession under the hood InferenceSession& inference_session_; diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 73638e8ba62a0..2d75a02004ff2 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -470,7 +470,8 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev // Get the first two dims value of input_ids which is [batch_size, seq_len] NodeArg* first_two_dims_arg = GetDimsValue(graph, input_ids_arg, - CreateInitializerFromVector(graph, {2}, {0, 1}, graph.GenerateNodeArgName("first_two_indices")), + CreateInitializerFromVector(graph, {2}, {0, 1}, + graph.GenerateNodeArgName("first_two_indices")), *embedding_node); // Add flatten pattern to each input node of the subgraph diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer.cc index 88c786d693cae..834e5ebb5f6f3 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer.cc @@ -1,233 +1,84 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include +#include +#include +#include + #include "core/framework/random_seed.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" #include "orttraining/core/graph/recompute_graph_utils.h" #include "orttraining/core/optimizer/memory_optimizer.h" +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" +#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" +#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" namespace onnxruntime { namespace { -constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; - -std::string TensorShapeProtoToString(const ONNX_NAMESPACE::TensorShapeProto* shape) { - std::ostringstream shape_oss; - if (shape != nullptr) { - for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) { - auto dim = shape->dim(dim_index); - if (utils::HasDimValue(dim)) { - shape_oss << dim.dim_value() << " x "; - } else { - shape_oss << dim.dim_param() << " x "; - } - } - } else { - shape_oss << "unknown"; - } - - return shape_oss.str(); -} - -int ParseIntValueFromString(std::string_view str) { - int int_value = 0; - auto result = std::from_chars(str.data(), str.data() + str.size(), int_value); - ORT_ENFORCE(result.ec != std::errc::invalid_argument, "Fail to convert to int from string: ", str); - return int_value; -} - -constexpr bool IsForwardPassOperator(ptrdiff_t op_order_in_topological_sort, ptrdiff_t boundary_op_order_in_topological_sort) { +constexpr bool IsForwardPassOperator(ptrdiff_t op_order_in_topological_sort, + ptrdiff_t boundary_op_order_in_topological_sort) { return op_order_in_topological_sort <= boundary_op_order_in_topological_sort; } -static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { - const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); - MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); - const TensorTypeBase* tensor_type_base = 
ml_data_type->AsTensorType(); - ORT_ENFORCE(nullptr != tensor_type_base); - MLDataType elt_type = tensor_type_base->GetElementType(); - return elt_type->Size(); -} - -// TODO(pengwa): extend this function to be more general. -float InputOutputSizeRatio(const Node* node) { - if (node->OpType().compare("Cast") == 0) { - const NodeArg* input = node->InputDefs()[0]; - const NodeArg* output = node->OutputDefs()[0]; - if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING || - output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { - return 1.0f; - } - const auto& ptype1 = input->Type(); - const auto& ptype2 = output->Type(); - float ratio = float(GetElementSize(ptype1)) / (float)GetElementSize(ptype2); - return ratio; - } - - return 1.0f; -} - } // namespace -Status MemoryOptimizer::ParseConfigFromString(const std::string& enable_memory_optimizer, +Status MemoryOptimizer::ParseConfigFromString(const std::string& memory_optimizer_config, const std::string& level) { - optimizer_config_ = enable_memory_optimizer; - if (!enable_memory_optimizer.empty()) { - const auto user_config_strs = utils::SplitString(enable_memory_optimizer, ","); - for (const auto& user_config_str : user_config_strs) { - const auto user_config = utils::SplitString(user_config_str, ":"); - ORT_RETURN_IF_NOT(user_config.size() == 3, - "User config should be in format of SubgraphStr:OptimizationType:RequestApplyCount."); - - const std::string subgraph_string_representation(user_config[0]); - int optimization_type_int = ParseIntValueFromString(user_config[1]); - int requested_apply_count = ParseIntValueFromString(user_config[2]); - ORT_RETURN_IF_NOT(optimization_type_int < static_cast(OptimizationType::TypeMax) && - optimization_type_int >= 0, - "Invalid optimization type specified for subgraph: ", - subgraph_string_representation); - - ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0, - "Invalid requested_apply_count specified for subgraph: ", requested_apply_count); - - // At this point, subgraph_string_representation is a pattern graph string representation. - pattern_subgraph_to_user_optimizer_config_map_[subgraph_string_representation] = - UserConfig{static_cast(optimization_type_int), requested_apply_count}; - } - } - - int probe_level = ParseIntValueFromString(level); - ORT_RETURN_IF_NOT(probe_level < static_cast(ProbeLevel::LevelMax) && probe_level >= 0, - "Invalid probe level specified: ", level); - recompute_probe_level_ = static_cast(probe_level); - - return Status::OK(); -} - -int64_t MemoryOptimizer::PrepareForTransformation(const Graph& graph, - ActivationUsedMap& fw_op_output_arg_used_map, - InlinedHashMap& - node_index_to_its_order_in_topological_sort_map) const { - fw_op_output_arg_used_map.clear(); - - GraphViewer graph_viewer(graph); - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); + optimizer_config_ = memory_optimizer_config; - // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. - ptrdiff_t yield_op_order_in_topological_sort = -1; - for (size_t i = 0; i < node_ids.size(); ++i) { - const Node* p_node = graph.GetNode(node_ids[i]); - if (p_node == nullptr) { /* skip removed nodes*/ - continue; - } - - if (p_node->OpType() == "YieldOp") { - yield_op_order_in_topological_sort = static_cast(i); - } - - node_index_to_its_order_in_topological_sort_map[p_node->Index()] = i; - } - - // If boundary op found, create forward op output arg used map. 
- if (yield_op_order_in_topological_sort >= 0) { - for (size_t i = 0; i < node_ids.size(); ++i) { - const Node* p_node = graph.GetNode(node_ids[i]); - if (p_node == nullptr /* skip removed nodes*/) { - continue; - } + ORT_RETURN_IF_ERROR(optimizer::memory_optimizer::ParseConfigFromString( + memory_optimizer_config, + pattern_subgraph_to_user_optimizer_config_map_)); - const Node& node = *p_node; - bool is_forward_op = IsForwardPassOperator(static_cast(i), yield_op_order_in_topological_sort); - if (!is_forward_op) { - continue; - } - - for (auto& output_arg : node.OutputDefs()) { - bool used_in_fw = false; - bool used_in_bw = false; - for (auto& consumer_node : graph.GetConsumerNodes(output_arg->Name())) { - size_t consumer_node_index_in_topological_order = - node_index_to_its_order_in_topological_sort_map.at(consumer_node->Index()); - if (IsForwardPassOperator(static_cast(consumer_node_index_in_topological_order), - yield_op_order_in_topological_sort)) { - used_in_fw = true; - } else { - used_in_bw = true; - } - } - fw_op_output_arg_used_map.insert({{output_arg->Name(), std::make_pair(used_in_fw, used_in_bw)}}); - } - } - } - - // Return whether boundary op is found or not. - return yield_op_order_in_topological_sort; -} - -Status MemoryOptimizer::GetStashedActivationCandidates(const Graph& graph, - const InlinedHashMap>& - fw_op_output_arg_used_map, - InlinedHashMap>& - candidate_output_args_map, - const logging::Logger& logger) const { - for (auto& kv : fw_op_output_arg_used_map) { - // used by fw and bw, then it is a candidates. - if (kv.second.first && kv.second.second) { - const Node* n = graph.GetProducerNode(kv.first); - ORT_ENFORCE(n, "Activation should have a producer node"); - size_t k = 0; - for (k = 0; k < n->OutputDefs().size(); ++k) { - if (n->OutputDefs()[k]->Name().compare(kv.first) == 0) { - break; - } - } - - candidate_output_args_map[n].push_back(k); - LOGS(logger, VERBOSE) << "Find candidate output named [" << kv.first << "] of Node " << n->Name() << "(" - << n->OpType() << ")"; - } - } + int probe_level = optimizer::memory_optimizer::ParseIntValueFromString(level); + ORT_RETURN_IF_NOT(probe_level < static_cast(optimizer::memory_optimizer::ProbeLevel::LevelMax) && + probe_level >= 0, + "Invalid probe level specified: ", level); + recompute_probe_level_ = static_cast(probe_level); return Status::OK(); } bool MemoryOptimizer::ModifyGraph(Graph& graph, - const InlinedHashMap& + const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, const InlinedHashMap>& candidate_output_args_map, const logging::Logger& logger, - int64_t boundary_op_order_in_topological_sort, - SubGraphStores& subgraph_stores, - Node* node) const { + ptrdiff_t boundary_op_order_in_topological_sort, + Node* node, + std::shared_ptr& node_plan, + std::shared_ptr& apply_context) + const { bool graph_is_modified = false; - if (subgraph_stores.SubGraphDescCount() == 0) { - return graph_is_modified; - } - - SubGraphStores::GraphInstanceInfo& sub_graph_instance_info = - subgraph_stores.GetSubGraphInstance(node); - - SubGraphDesc& subgraph_desc = subgraph_stores.GetSubGraphDesc(sub_graph_instance_info.second); - UserConfig user_config = subgraph_desc.user_optimizer_config; - int skip_count = (user_config.requested_count == -1) + int skip_count = (apply_context->requested_count == -1) ? 
0 - : std::max(0, subgraph_desc.total_frequency - user_config.requested_count); + : std::max(0, apply_context->total_frequency - apply_context->requested_count); - subgraph_desc.skip_count += 1; + apply_context->skip_count += 1; - if (user_config.type != OptimizationType::None && subgraph_desc.skip_count > skip_count) { - subgraph_desc.applied_count += 1; + if (apply_context->skip_count > skip_count) { + apply_context->applied_count += 1; Node* replacement_node_ptr = nullptr; - LOGS(logger, WARNING) << "[Modify Graph] Node " << node->Name() << "(" << node->OpType() << ") is " - << UserConfigToString(user_config); - if (user_config.type == OptimizationType::Recompute) { - ORT_ENFORCE(CreateRecomputeGraph(graph, sub_graph_instance_info.first, replacement_node_ptr).IsOK()); + LOGS(logger, INFO) << "Node " << node->Name() << "(" << node->OpType() << ") is applying following optimization:" + << "type [" << optimizer::memory_optimizer::OptimizationTypeToString(apply_context->type) + << "], request count [" << apply_context->requested_count << "]"; + if (apply_context->type == optimizer::memory_optimizer::OptimizationType::Recompute || + apply_context->type == optimizer::memory_optimizer::OptimizationType::RecomputeWithCompromise) { + optimizer::memory_optimizer::NodeRecomputePlan* recompute_plan = + dynamic_cast(node_plan.get()); + ORT_ENFORCE(recompute_plan != nullptr); + ORT_ENFORCE(CreateRecomputeGraph(graph, recompute_plan->GetNodesInTopoOrder(), replacement_node_ptr).IsOK()); } else { - ORT_THROW("unsupported optimization type found: " + UserConfigToString(user_config)); + ORT_THROW("unsupported optimization type found."); } ORT_ENFORCE(replacement_node_ptr); @@ -278,60 +129,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve LOGS(logger, VERBOSE) << "Memory optimization config: " << optimizer_config_ << ", probe level: " << static_cast(recompute_probe_level_); - InlinedHashMap> fw_op_output_arg_used_map; - InlinedHashMap node_index_to_its_order_in_topological_sort_map; - int64_t boundary_op_order_in_topological_sort = - PrepareForTransformation(graph, fw_op_output_arg_used_map, - node_index_to_its_order_in_topological_sort_map); - if (boundary_op_order_in_topological_sort < 0) { - LOGS(logger, VERBOSE) << "No boundary op found. Skip memory optimization."; + if (pattern_subgraph_to_user_optimizer_config_map_.empty()) { + LOGS(logger, VERBOSE) << "No optimization pattern is specified, skip memory optimization."; return Status::OK(); } + ptrdiff_t yield_op_order_in_topological_sort; InlinedHashMap> candidate_output_args_map; - ORT_RETURN_IF_ERROR(GetStashedActivationCandidates(graph, fw_op_output_arg_used_map, candidate_output_args_map, - logger)); - - SubGraphStores recompute_subgraph_stores; - SubGraphStores recompute_with_compromise_subgraph_stores; - GraphViewer graph_viewer(graph); - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); + InlinedHashMap node_index_to_its_order_in_topological_sort_map; // The first pass - find the candidate subgraphs. 
- for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { - Node* p_node = graph.GetNode(node_ids[i]); - if (p_node == nullptr) { - continue; - } - - if (candidate_output_args_map.find(p_node) == candidate_output_args_map.end()) { - continue; - } + GraphViewer graph_viewer(graph); + optimizer::memory_optimizer::MemoryOptimizationPlanner memory_opt_planner; + ORT_ENFORCE(optimizer::memory_optimizer::FindORTModuleMemoryOpportunity( + graph_viewer, + recompute_probe_level_, + logger, + node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, + candidate_output_args_map, + memory_opt_planner) + .IsOK()); - bool can_compromise_stashed_activation = false; - CheckNodeForRecompute(*p_node, fw_op_output_arg_used_map, - node_index_to_its_order_in_topological_sort_map, - candidate_output_args_map, - recompute_subgraph_stores, logger, false, - can_compromise_stashed_activation); - - if (can_compromise_stashed_activation) { - LOGS(logger, VERBOSE) << "Searching Node " << p_node->Name() << "(" << p_node->OpType() - << ") for compromised recompute"; - // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist - // during backward pass, then we can try to compromise the assumption. - CheckNodeForRecompute(*p_node, fw_op_output_arg_used_map, node_index_to_its_order_in_topological_sort_map, - candidate_output_args_map, - recompute_with_compromise_subgraph_stores, logger, true, - can_compromise_stashed_activation); - } - } + // Finalize the plan according to user config, + // then create a ClusterApplyContext for each unique cluster (having the same node pattern) + InlinedHashMap> + node_to_opt_plan_map; + optimizer::memory_optimizer::NodeToClusterApplyContextMap node_to_apply_context_map; + ORT_ENFORCE(memory_opt_planner.FinalizeNodePlansFromUserConfig(pattern_subgraph_to_user_optimizer_config_map_, + node_to_opt_plan_map, + node_to_apply_context_map) + .IsOK()); // The second pass - apply the transformation. // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { @@ -339,374 +174,40 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve } bool has_been_modified = false; - if (recompute_subgraph_stores.ContainsSubGraphInstance(p_node)) { + if (node_to_opt_plan_map.find(p_node) != node_to_opt_plan_map.end()) { has_been_modified = ModifyGraph(graph, node_index_to_its_order_in_topological_sort_map, candidate_output_args_map, logger, - boundary_op_order_in_topological_sort, - recompute_subgraph_stores, p_node); - } - - // If there are other recompute plan for this node, we skip them because the graph is already modified. 
- if (!has_been_modified && recompute_with_compromise_subgraph_stores.ContainsSubGraphInstance(p_node)) { - has_been_modified = ModifyGraph(graph, node_index_to_its_order_in_topological_sort_map, - candidate_output_args_map, logger, - boundary_op_order_in_topological_sort, - recompute_with_compromise_subgraph_stores, p_node); + yield_op_order_in_topological_sort, + p_node, + node_to_opt_plan_map[p_node], + node_to_apply_context_map[p_node]); } modified = modified || has_been_modified; } - PrintSummary(recompute_subgraph_stores, recompute_with_compromise_subgraph_stores, logger); + PrintSummary(memory_opt_planner, node_to_apply_context_map, logger); return Status::OK(); } -void MemoryOptimizer::NodesInTopoOrderToString(const InlinedVector& nodes_in_topological_order, - std::string& subgraph_string_representation, - std::string& log_info) const { - std::ostringstream oss; - std::ostringstream subgraph_string_representation_oss; - size_t node_count = nodes_in_topological_order.size(); - for (size_t i = 0; i < node_count; ++i) { - if (i < node_count - 1) { // Ignore the last node. - oss << "(name:" << nodes_in_topological_order[i]->Name() << ", type:" << nodes_in_topological_order[i]->OpType() - << "),"; - } - - subgraph_string_representation_oss << nodes_in_topological_order[i]->OpType() << "+"; - } - - subgraph_string_representation = subgraph_string_representation_oss.str(); - log_info = oss.str(); - if (log_info.size() > 0) { - log_info = " with its precedent nodes: " + log_info; - } -} - -std::string MemoryOptimizer::UserConfigToString(const UserConfig& config) const { - std::string type_str; - switch (config.type) { - case OptimizationType::None: { - type_str = "Disabled"; - } break; - case OptimizationType::Recompute: { - type_str = "Recomputed"; - } break; - default: { - type_str = "Unknown"; - } break; - } - return type_str; -} - -void MemoryOptimizer::PrintSummary(const SubGraphStores& recompute_stores, - const SubGraphStores& recompute_with_compromise_stores, +void MemoryOptimizer::PrintSummary(const optimizer::memory_optimizer::MemoryOptimizationPlanner& memory_opt_planner, + const InlinedHashMap< + const Node*, + std::shared_ptr>& + node_to_apply_contexts_map, const logging::Logger& logger) const { - if (recompute_stores.SubGraphDescCount() == 0 && recompute_with_compromise_stores.SubGraphDescCount() == 0) { - return; - } - - std::ostringstream summary; - summary << "\nMemoryOptimizer Summary:\n"; - summary << "\tUser config:\n\t" << optimizer_config_ << "\n"; - summary << "\t=================================\n"; - - auto print_info_from_stores = [&summary, this](std::string store_name, const SubGraphStores& stores) { - summary << "\t########" << store_name << "########\n"; - for (auto subgraph_it = stores.subgraph_descs.begin(); subgraph_it != stores.subgraph_descs.end(); - ++subgraph_it) { - std::string freq_info; - if (subgraph_it->second.user_optimizer_config.type != OptimizationType::None) - freq_info = " (requested_count=" + std::to_string(subgraph_it->second.user_optimizer_config.requested_count) + - ", actual applied_count=" + - std::to_string(subgraph_it->second.applied_count) + ")"; - summary << "\tSubgraph: " << subgraph_it->first << "\n" - << "\t\tOptimizationType: " - << UserConfigToString(subgraph_it->second.user_optimizer_config) << freq_info << "\n" - << "\t\tPatterns: \n"; - for (auto shape_stat_it = subgraph_it->second.shape_str_frequency.begin(); - shape_stat_it != subgraph_it->second.shape_str_frequency.end(); - ++shape_stat_it) { - summary << 
"\t\t\tPatternShape:" << shape_stat_it->first << "\tFrequency:" << shape_stat_it->second << "\n"; - } - summary << "\t--------------------------------\n"; - } - summary << "\t=================================\n"; - }; - - print_info_from_stores("Recompute", recompute_stores); - print_info_from_stores("RecomputeWithCompromise", recompute_with_compromise_stores); - - LOGS(logger, INFO) << summary.str() << "\n"; + std::vector> records_grouped_by_node_cluster_id; + optimizer::memory_optimizer::GetMemoryRecordsGroupedByNodeClusterId(memory_opt_planner, + node_to_apply_contexts_map, + records_grouped_by_node_cluster_id); + LOGS(logger, INFO) << SerializeMemoryRecords(records_grouped_by_node_cluster_id, optimizer_config_) << "\n"; } /****************************************************** ** Recompute related function implementation starts ** ******************************************************/ -void MemoryOptimizer::RegisterAllowedRecomputeOps() { - if (static_cast(recompute_probe_level_) >= static_cast(ProbeLevel::Basic)) { - recomputable_op_type_to_input_arg_index_map_.insert({ - // Binary elementwise - {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, - - // Data layout - /// The shape input is trivial whether it exists or not in backward. - {"Reshape", AllowedRecomputeNodeConfig{{0}}}, - {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, - {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, - - // Unary elementwise - /// The ratio and mode input are trivial whether they exist or not in backward - {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}}, - /// The axis input is trivial whether it exists or not in backward - {"CumSum", AllowedRecomputeNodeConfig{{0}}}, - {"Dropout", AllowedRecomputeNodeConfig{{0}}}, - {"Gelu", AllowedRecomputeNodeConfig{{0}}}, - {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, - - // Ternary elementwise - {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, - - // Data copy - {"Tile", AllowedRecomputeNodeConfig{{0}}}, - {"Cast", AllowedRecomputeNodeConfig{{0}}}, - }); - } - - if (static_cast(recompute_probe_level_) >= static_cast(ProbeLevel::Advanced)) { - recomputable_op_type_to_input_arg_index_map_.insert({ - {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Softmax", AllowedRecomputeNodeConfig{{0}}}, - {"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}}, - }); - } -} - -Status MemoryOptimizer::SelectRecomputeSubgraph(const Node& entry_node, - const InlinedVector& node_output_index_candidates, - const ActivationUsedMap& fw_op_output_arg_used_map, - const InlinedHashMap& - node_index_to_its_order_in_topological_sort_map, - InlinedVector& nodes, - const logging::Logger& logger, - bool compromise_stashed_activation, - bool& can_compromise_stashed_activation) const { - can_compromise_stashed_activation = false; - - LOGS(logger, VERBOSE) << "Enter SelectRecomputeSubgraph for Node " << entry_node.Name() << "(" << entry_node.OpType() << ")"; - nodes.clear(); - - std::deque q; - for (auto output_index : node_output_index_candidates) { - q.push_back(NodeOutputPort(&entry_node, static_cast(output_index))); - } - - bool early_stop = false; - std::set visited_output_arg_set; - std::set visited_node_set; - - // For the initial activations in queue, they are stashed ones, so we do 
differently when scan the queue for them. - bool is_first_queue_scan = true; - while (nodes.size() < MAXIMUM_RECOMPUTE_NODE_COUNT && !q.empty() && !early_stop) { - // Loop all candidate NodeOutputPort, and find the next layer of input nodes. - size_t current_queue_size = q.size(); - for (size_t i = 0; i < current_queue_size; ++i) { - NodeOutputPort p = q.front(); - q.pop_front(); - const Node* curr_node = p.first; - - // Skip if the node output is already visited. - if (std::find(visited_output_arg_set.begin(), visited_output_arg_set.end(), p) != - visited_output_arg_set.end()) { - continue; - } - - visited_output_arg_set.insert({p}); - - // If the node already visited by from it's other output index, skip it. - if (visited_node_set.find(curr_node) != visited_node_set.end()) { - continue; - } - - visited_node_set.insert(curr_node); - - // Bottom-up search rules. - // If current op is entry output node (that generates stashed activations): - // 1. If the op is not in recomputable_op_type_to_input_arg_index_map_, skip it. - // Otherwise: - // If current op is in allowed list, check its input args, and append the producers' NodeOutputPorts to next_q. - // If current op is NOT in allowed list: - // 1). the output does not exist in backward, we cannot find a good solution for so, search terminates. - // 2). the output is used in backward, we don't need trace back further, continue searching. - auto op_recompute_config_it = recomputable_op_type_to_input_arg_index_map_.find(curr_node->OpType()); - auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name(); - if (is_first_queue_scan) { - // We handle the entry node outputs differently because, we don't want this case falls into and succeed one of - // the checks in the other branch - // 1. "op is not in recompute op list, but its output is used in backward" - // 2. "op is in recompute op list, but its output is used in backward" - // (either of the above checks is true for entry node outputs) - if (op_recompute_config_it == recomputable_op_type_to_input_arg_index_map_.end()) { - early_stop = true; - LOGS(logger, VERBOSE) << "Entry Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** " - << "in recompute op list, search terminates."; - break; - } - } else { - if (op_recompute_config_it == recomputable_op_type_to_input_arg_index_map_.end()) { - if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { - LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " - << "recompute op list, but its output [" << cur_output_arg_name << "] is used in " - << "backward, we don't need trace bottom-up further. Entry node: " - << entry_node.Name() << "(" << entry_node.OpType() << ")"; - continue; - } else { - early_stop = true; - LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " - << "recompute op list, and its output [" << cur_output_arg_name - << "] does not exist in backward, search terminates. Entry node: " - << entry_node.Name() << "(" << entry_node.OpType() << ")"; - break; - } - } - - if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { - LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") " - << "is in recompute op list, while its output [" << cur_output_arg_name - << "] is used in backward, we don't need trace bottom-up further. Entry node: " - << entry_node.Name() << "(" << entry_node.OpType() << ")"; - continue; - } - } - - // Append node to the selected graph. 
- if (std::find(nodes.begin(), nodes.end(), curr_node) == nodes.end()) { - nodes.push_back(curr_node); - LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() - << ") is added in selected subgraph "; - } - - // This check is not matured now, subject to be changed. - float ratio = InputOutputSizeRatio(curr_node); - float is_current_node_compromisable = (ratio < 1.f); - can_compromise_stashed_activation = can_compromise_stashed_activation || is_current_node_compromisable; - if (is_current_node_compromisable) { - LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() - << ") has input/output size " << ratio << " < 1.f, can compromise stashed activation"; - } - - if (is_current_node_compromisable && compromise_stashed_activation) { - LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is in " - << "recompute op list, and its output [" << cur_output_arg_name - << "] does not exist in backward, while it meet compromised check, we don't need trace " - << "bottom-up further."; - continue; - } - - // Iterate all input nodes according to allowed input arg index of the entry node. - const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices; - for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) { - const Node::EdgeEnd& input_edge = *it; - const auto& parent_node = input_edge.GetNode(); - const auto parent_node_output_index = input_edge.GetSrcArgIndex(); - const auto current_node_input_index = input_edge.GetDstArgIndex(); - if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != - input_arg_indices.end()) { - NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); - - LOGS(logger, VERBOSE) << "Node " << parent_node.Name() << "(" << parent_node.OpType() << ")'s " - << parent_node_output_index - << "th output [" << parent_node.OutputDefs()[parent_node_output_index]->Name() - << "] is added in recompute search list "; - - q.push_back(next_p); - } - } - } - // After handle all entry node outputs, we set the flag to false. - is_first_queue_scan = false; - } - - // If input args are not found in bw, but op count exceed MAXIMUM_RECOMPUTE_NODE_COUNT, skip recompute. - if (!q.empty() || early_stop) { - LOGS(logger, VERBOSE) << "Fail to find a solution for recompute: current node count is " << nodes.size() - << ", queue size: " << q.size() << ", early stop: " << early_stop; - nodes.clear(); - } else { - // Re-order the nodes in topological order. 
- std::sort(nodes.begin(), nodes.end(), - [&node_index_to_its_order_in_topological_sort_map](const Node*& lhs, const Node*& rhs) { - return node_index_to_its_order_in_topological_sort_map.at(lhs->Index()) < - node_index_to_its_order_in_topological_sort_map.at(rhs->Index()); - }); - } - return Status::OK(); -} - -void MemoryOptimizer::CheckNodeForRecompute(const Node& node, - const ActivationUsedMap& fw_op_output_arg_used_map, - const InlinedHashMap& - node_index_to_its_order_in_topological_sort_map, - const InlinedHashMap>& - candidate_output_args_map, - SubGraphStores& subgraph_stores, - const logging::Logger& logger, - bool compromise_stashed_activation, - bool& can_compromise_stashed_activation) const { - if (recomputable_op_type_to_input_arg_index_map_.find(node.OpType()) == - recomputable_op_type_to_input_arg_index_map_.end()) { - return; - } - - InlinedVector nodes_in_topological_order; - ORT_ENFORCE(SelectRecomputeSubgraph(node, candidate_output_args_map.at(&node), - fw_op_output_arg_used_map, - node_index_to_its_order_in_topological_sort_map, - nodes_in_topological_order, logger, - compromise_stashed_activation, - can_compromise_stashed_activation) - .IsOK()); - if (nodes_in_topological_order.size() == 0) { - return; - } - - std::string subgraph_str_representation, log_info; - NodesInTopoOrderToString(nodes_in_topological_order, subgraph_str_representation, log_info); - LOGS(logger, VERBOSE) << "Node " << node.Name() << "(" << node.OpType() << ") can be recomputed" << log_info; - - // Update the subgraph optimization config map - key is the subgraph string representation, value is user config. - UserConfig user_config{OptimizationType::None, 0}; - if (pattern_subgraph_to_user_optimizer_config_map_.find(subgraph_str_representation) != - pattern_subgraph_to_user_optimizer_config_map_.end()) { - user_config = pattern_subgraph_to_user_optimizer_config_map_.at(subgraph_str_representation); - } - - SubGraphDesc& subgraph_desc = - subgraph_stores.Contains(subgraph_str_representation) - ? subgraph_stores.GetSubGraphDesc(subgraph_str_representation) - : subgraph_stores.CreateSubGraphDesc(subgraph_str_representation, user_config); - - subgraph_desc.total_frequency += 1; - - // Update the subgraph frequency map - key is the subgraph string representation, value is number of appearances. - for (size_t output_index : candidate_output_args_map.at(&node)) { - auto shape_str = TensorShapeProtoToString(node.OutputDefs()[output_index]->Shape()); - subgraph_desc.shape_str_frequency[shape_str]++; - } - - subgraph_stores.AddSubGraphInstance(&node, nodes_in_topological_order, subgraph_desc); - - return; -} - Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, const InlinedVector& nodes_in_topological_order, Node*& new_output_node_ptr) const { @@ -716,8 +217,8 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, // Check whether the node has been recomputed/offloaded or not. Simply check the existence of the first output // of the node has its corresponding recompute name or not. - // TODO: if there is more optimization types like offload added, we will add corresponding check whether the outputs - // already be offloaded or not. + // TODO: if there is more optimization types like offload added, we will add a corresponding check + // whether the outputs already be offloaded or not. 
if (graph.GetNodeArg(graph_utils::RecomputeName(node_to_duplicate->MutableOutputDefs()[0]->Name())) != nullptr) { continue; } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer.h index 1d21c9143f62f..13eb4cdb242f4 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer.h @@ -2,163 +2,39 @@ // Licensed under the MIT License. #pragma once -#include + #include "core/common/inlined_containers.h" #include "core/common/string_utils.h" #include "core/optimizer/graph_transformer.h" +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" +#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" +#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" namespace onnxruntime { /** @Class MemoryOptimizer -Find recomputable subgraphs and enable according to user configs. +(TODO) move to orttraining/orttraining/core/optimizer/memory_optimizer/ folder. + +Find recompute subgraphs and enable them according to user configs. In brief, the way we collect subgraphs +(in orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h) is: +1. Find all nodes that generate stashed activations. +2. For each node, check whether its data type is supported for recompute: + a. If yes, add it to the subgraph and append its inputs to the queue to scan next; + b. otherwise, stop collecting and return the subgraph (which could be empty). +3. Pick the next input node from the queue and repeat step 2. The process ends when the queue is empty or case 2.b happens. +4. Clone the recomputable subgraphs with a lower node priority (to execute) and insert them back into the original graph. */ class MemoryOptimizer : public GraphTransformer { private: - using NodeOutputPort = std::pair; - using ActivationUsedMap = InlinedHashMap>; - - /** - * @brief Level to control allowed operations during subgraph detecting. - * Level 0: only allow cheap-to-compute operations. - * Level 1: allow more expensive operations. - */ - enum class ProbeLevel { - Basic = 0, - Advanced = 1, - LevelMax = 2, - }; - - /** - * @brief Type of memory reduction techniques. - */ - enum class OptimizationType { - None = 0, // Disabled. - Recompute = 1, - TypeMax = 2, - }; - - /** - * @brief Type of user config. - * type: type of memory reduction techniques. - * requested_count: the number of occurrences of a subgraph pattern for alleviation. -1 means apply all. - * One example: if a subgraph pattern is found 3 times, and requested_count is set 2, then the 1st and 2nd subgraph - * in topological order will be applied for alleviation. This is useful to avoid alleviating more memory than - * needed. - */ - struct UserConfig { - OptimizationType type; - int requested_count; - }; - - /** - * @brief Struct to store properties of a specific subgraph. - */ - struct SubGraphDesc { - SubGraphDesc() = default; - - // A string to represent the subgraph, used as a unique "ID" for a unique subgraph. - std::string subgraph_representative_str; - - InlinedHashMap shape_str_frequency; // shape string to frequency - UserConfig user_optimizer_config; - int total_frequency{0}; // The occurrence of this subgraph pattern in the graph. - - int applied_count{0}; // The number of times this subgraph pattern has been really applied in this transformer.
- int skip_count{0}; // The number of times this subgraph instance has been skipped in reversed topological order. - float saving_ratio{1.0f}; // For compromised memory saving, the ratio of memory saving. - }; - - /** - * @brief A struct to maintain the information of target subgraphs to optimize. - * Imagine we loop all nodes finding recomputable/offload-able subgraphs, we want to store them first. - * Afterwards, we optionally pick up some of them to apply optimization according to user configs. - * - * subgraph_descs is a map from subgraph string representation to its subgraph related configurations. - * - * _optimization_target_graphs_ is a map from activation producer node pointers to its target optimization subgraph - * nodes. For example, if a subgraph Cast+Gelu can be recomputed, we may have a map like: - * key: node pointer of stashed activation producer Gelu; value: node vector {Cast, Gelu,}. - * - * When we AddSubGraphInstance, we must provider its corresponding subgraph desc in the parameter. - * Then we can know for each subgraph instance, what's the subgraph str representation, and what's the optimization - * config. - */ - struct SubGraphStores { - /********************************** - ** subgraph desc section starts ** - **********************************/ - - size_t SubGraphDescCount() const { - return subgraph_descs.size(); - } - - bool Contains(std::string_view subgraph_str) const { - return subgraph_descs.find(subgraph_str) != subgraph_descs.end(); - } - - SubGraphDesc& GetSubGraphDesc(std::string_view subgraph_string) { - ORT_ENFORCE(Contains(subgraph_string), "Subgraph string not found.", subgraph_string); - return subgraph_descs.at(subgraph_string); - } - - SubGraphDesc& CreateSubGraphDesc(const std::string& subgraph_string, - UserConfig& config) { - ORT_ENFORCE(!Contains(subgraph_string), "Subgraph string already exists.", subgraph_string); - subgraph_descs[subgraph_string].user_optimizer_config = config; - subgraph_descs[subgraph_string].subgraph_representative_str = subgraph_string; - return subgraph_descs[subgraph_string]; - } - - /********************************************************************** - ** subgraph desc section ends, and subgraph instance section starts. ** - ***********************************************************************/ - - // Pair of . - using GraphInstanceInfo = std::pair, std::string>; - - void AddSubGraphInstance(const Node* node, - const InlinedVector& nodes_in_topological_order, - const SubGraphDesc& subgraph_desc) { - ORT_ENFORCE(_optimization_target_graphs_.find(node) == _optimization_target_graphs_.end()); - _optimization_target_graphs_[node] = std::make_pair(nodes_in_topological_order, - subgraph_desc.subgraph_representative_str); - } - - bool ContainsSubGraphInstance(const Node* node) const { - return _optimization_target_graphs_.find(node) != _optimization_target_graphs_.end(); - } - - GraphInstanceInfo& GetSubGraphInstance(const Node* node) { - ORT_ENFORCE(_optimization_target_graphs_.find(node) != _optimization_target_graphs_.end()); - return _optimization_target_graphs_[node]; - } - - /*********************************** - ** subgraph instance section ends ** - ***********************************/ - - InlinedHashMap subgraph_descs; - InlinedHashMap _optimization_target_graphs_; - }; - - /** - * @brief Used to define per-op recompute config. 
- * - */ - struct AllowedRecomputeNodeConfig { - InlinedVector input_arg_indices; // input index to iterate further (bottom up) - }; - public: - MemoryOptimizer(const std::string& enable_memory_optimizer, const std::string& level) + MemoryOptimizer(const std::string& memory_optimizer_config, const std::string& level) : GraphTransformer("MemoryOptimizer") { // Parse user defined configs. - ORT_ENFORCE(ParseConfigFromString(enable_memory_optimizer, level).IsOK()); - - RegisterAllowedRecomputeOps(); + ORT_ENFORCE(ParseConfigFromString(memory_optimizer_config, level).IsOK()); } Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; @@ -166,35 +42,7 @@ class MemoryOptimizer : public GraphTransformer { bool ShouldOnlyApplyOnce() const override { return true; } private: - Status ParseConfigFromString(const std::string& enable_memory_optimizer, const std::string& level); - - /** - * @brief Prepare info including activation usage, node usage in fw and bw. - * - * @param graph Graph to iterate. - * @param fw_op_output_arg_used_map Collected activation usage mapping. - * - key: node arg name - * - value: a pair of bool, representing whether the activation is used by forward nodes or by backward nodes. - * @return int64_t value The boundary op (for example YieldOp) order in topological order. If no boundary op found, - * return -1; - */ - int64_t PrepareForTransformation(const Graph& graph, - ActivationUsedMap& fw_op_output_arg_used_map, - InlinedHashMap& - node_index_to_its_order_in_topological_sort_map) const; - /** - * @brief Find all stashed activations, e.g. activations used by forward operators and backward operators. - * - * @param graph Graph to iterate. - * @param fw_op_output_arg_used_map Activation usage mapping. - * @param candidate_output_args_map Candidate activations, which are consumed by both fw and bw ops. - * @return Status - */ - Status GetStashedActivationCandidates( - const Graph& graph, - const InlinedHashMap>& fw_op_output_arg_used_map, - InlinedHashMap>& candidate_output_args_map, - const logging::Logger& logger) const; + Status ParseConfigFromString(const std::string& memory_optimizer_config, const std::string& level); /** * @brief Apply graph modifications based on user configs. @@ -212,28 +60,15 @@ class MemoryOptimizer : public GraphTransformer { * @return false */ bool ModifyGraph(Graph& graph, - const InlinedHashMap& node_index_to_its_order_in_topological_sort_map, - const InlinedHashMap>& candidate_output_args_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const InlinedHashMap>& + candidate_output_args_map, const logging::Logger& logger, - int64_t boundary_op_order_in_topological_sort, - SubGraphStores& subgraph_stores, - Node* node) const; - - /** - * @brief Convert the recompute subgraph to its string representation. - * - * @param nodes_in_topological_order The subgraph nodes in topological order. - * @param subgraph_string_representation Returns subgraph string representation. - * @param log_info Returns log info for users. - */ - void NodesInTopoOrderToString(const InlinedVector& nodes_in_topological_order, - std::string& subgraph_string_representation, - std::string& log_info) const; - - /** - * @brief Convert optimization type to string. 
- */ - std::string UserConfigToString(const UserConfig& config) const; + ptrdiff_t boundary_op_order_in_topological_sort, + Node* node, + std::shared_ptr& node_plan, + std::shared_ptr& apply_context) const; /** * @brief Summarize transformation details. @@ -241,72 +76,16 @@ class MemoryOptimizer : public GraphTransformer { * @param stashed_activation_statistics statistics around stashed activation memory saving. * @return void */ - void PrintSummary(const SubGraphStores& recompute_stores, - const SubGraphStores& recompute_with_compromise_stores, + void PrintSummary(const optimizer::memory_optimizer::MemoryOptimizationPlanner& mem_opt_stats, + const InlinedHashMap>& + node_to_apply_contexts_map, const logging::Logger& logger) const; /************************************************** ** Recompute related function definition starts ** *************************************************/ - void RegisterAllowedRecomputeOps(); - - /** - * @brief Find recomputable subgraphs (has at least one nodes, at most MAXIMUM_RECOMPUTE_NODE_COUNT nodes). - * - * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. - * @param node_output_index_candidates Candidate output indices of "node", which are consumed by both fw and bw ops. - * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. - * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. - * Used to re-order the collected subgraph nodes. - * @param nodes_in_topological_order Collected vector of nodes of found subgraph, in the order of the topological - * sorted. - * @param logger Logger. - * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a - * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the - * size of stashed activation. - * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a - * compromised subgraph. - * @return Status - */ - Status SelectRecomputeSubgraph(const Node& node, - const InlinedVector& node_output_index_candidates, - const ActivationUsedMap& fw_op_output_arg_used_map, - const InlinedHashMap& - node_index_to_its_order_in_topological_sort_map, - InlinedVector& nodes_in_topological_order, - const logging::Logger& logger, - bool compromise_stashed_activation, - bool& can_compromise_stashed_activation) const; - - /** - * @brief For the node producing stashed activation, check whether a recomputable subgraph can be found or not. - * - * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. - * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. - * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. - * Used to re-order the collected subgraph nodes. - * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and - * bw ops. - * @param subgraph_stores A store to maintain all found subgraphs. - * @param logger Logger. - * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a - * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the - * size of stashed activation. 
- * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a - * compromised subgraph. - */ - void CheckNodeForRecompute(const Node& node, - const ActivationUsedMap& fw_op_output_arg_used_map, - const InlinedHashMap& - node_index_to_its_order_in_topological_sort_map, - const InlinedHashMap>& - candidate_output_args_map, - SubGraphStores& subgraph_stores, - const logging::Logger& logger, - bool compromise_stashed_activation, - bool& can_compromise_stashed_activation) const; - /** * @brief Duplicate nodes to create a recompute subgraph. * @@ -323,12 +102,10 @@ class MemoryOptimizer : public GraphTransformer { ** Recompute related function definition ends ** *************************************************/ - // The op types that are supported predefined. - InlinedHashMap recomputable_op_type_to_input_arg_index_map_; // User enabled map of the subgraph string representation to the alleviation type. - InlinedHashMap pattern_subgraph_to_user_optimizer_config_map_; + InlinedHashMap pattern_subgraph_to_user_optimizer_config_map_; std::string optimizer_config_; - ProbeLevel recompute_probe_level_; + optimizer::memory_optimizer::ProbeLevel recompute_probe_level_; }; } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc new file mode 100644 index 0000000000000..2291d7e4f37a6 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.cc @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph_viewer.h" +#include "core/framework/tensorprotoutils.h" + +#include "core/common/string_utils.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +namespace { + +constexpr const char empty_dim_param_placeholder[] = "empty_dim_param"; +static size_t index_empty_dim = 0; + +bool TensorShapeProtoToDimParamVector(const ONNX_NAMESPACE::TensorShapeProto* shape, + std::vector& dim_params) { + bool has_unknown_dim = false; + for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) { + auto dim = shape->dim(dim_index); + if (utils::HasDimValue(dim)) { + dim_params.push_back(std::to_string(dim.dim_value())); + } else { + std::string trimmed_dim_param = utils::TrimString(dim.dim_param()); + if (trimmed_dim_param.empty()) { + has_unknown_dim = true; + dim_params.push_back(empty_dim_param_placeholder + std::to_string(index_empty_dim++)); + } else { + dim_params.push_back(trimmed_dim_param); + } + } + } + + if (shape->dim_size() == 0) { + dim_params.push_back("(1)"); // Scalar + } + + return has_unknown_dim; +} + +bool HasUnknowDimension(const ONNX_NAMESPACE::TensorShapeProto* shape) { + if (shape == nullptr) { + return true; + } + + std::vector dim_params; + return TensorShapeProtoToDimParamVector(shape, dim_params); +} + +std::string TensorShapeProtoToString(const ONNX_NAMESPACE::TensorShapeProto* shape) { + if (shape == nullptr) { + return "unknown"; + } + + std::vector dim_params; + TensorShapeProtoToDimParamVector(shape, dim_params); + + std::ostringstream oss; + oss << "("; + for (auto it = dim_params.begin(); it != dim_params.end(); ++it) { + oss << "(" << *it << ")"; + if (it != (dim_params.end() - 1)) { + oss << "*"; + } + } + oss << ")"; + + return 
oss.str(); +} + +} // namespace + +std::string GetTensorElemCountInSymbolicString(const Node* node, size_t output_index) { + const auto& output_def = node->OutputDefs()[output_index]; + const auto shape = output_def->Shape(); + + std::string shape_str = TensorShapeProtoToString(shape); + + // If the output shape contains unknown dimension, we try to get the shape from input. + // though the input shape might be different, but its elem size and count should be the same + // with the output. + if (node->OpType() == "Reshape" && HasUnknowDimension(shape) && + !HasUnknowDimension(node->InputDefs()[0]->Shape())) { + shape_str = TensorShapeProtoToString(node->InputDefs()[0]->Shape()); + } + + return shape_str; +} + +std::string OptimizationTypeToString(OptimizationType type) { + switch (type) { + case OptimizationType::None: + return "None"; + case OptimizationType::Recompute: + return "Recompute"; + case OptimizationType::RecomputeWithCompromise: + return "RecomputeWithCompromise"; + default: + ORT_THROW("Unknown optimization type."); + } +} + +int ParseIntValueFromString(std::string_view str) { + int int_value = 0; + auto result = std::from_chars(str.data(), str.data() + str.size(), int_value); + ORT_ENFORCE(result.ec != std::errc::invalid_argument, "Fail to convert to int from string: ", str); + return int_value; +} + +Status ParseConfigFromString(std::string_view memory_optimization_config, + InlinedHashMap& cluster_id_to_config_map) { + if (!memory_optimization_config.empty()) { + const auto user_config_strs = utils::SplitString(memory_optimization_config, ","); + for (const auto& user_config_str : user_config_strs) { + const auto user_config = utils::SplitString(user_config_str, ":"); + ORT_RETURN_IF_NOT(user_config.size() == 3, + "User config should be in format of SubgraphStr:OptimizationType:RequestApplyCount."); + + const std::string subgraph_string_representation(user_config[0]); + int optimization_type_int = ParseIntValueFromString(user_config[1]); + int requested_apply_count = ParseIntValueFromString(user_config[2]); + ORT_RETURN_IF_NOT(optimization_type_int < + static_cast(OptimizationType::TypeMax) && + optimization_type_int >= 0, + "Invalid optimization type specified for subgraph: ", + subgraph_string_representation); + + ORT_RETURN_IF_NOT(requested_apply_count == -1 || requested_apply_count >= 0, + "Invalid requested_apply_count specified for subgraph: ", requested_apply_count); + + // At this point, subgraph_string_representation is a pattern graph string representation. + // If duplicated subgraph_string_representation is found in user config, the last one will be used. + cluster_id_to_config_map[subgraph_string_representation] = UserConfig{ + static_cast(optimization_type_int), + requested_apply_count}; + } + } + + return Status::OK(); +} + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/common.h b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h new file mode 100644 index 0000000000000..85e2bf4f5d683 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/common.h @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
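For context, the ParseConfigFromString implementation above expects a comma-separated list of SubgraphStr:OptimizationType:RequestApplyCount entries. The following standalone sketch (editor's illustration, not part of this change; the pattern strings and the OptType/Config names are placeholders for the real OptimizationType/UserConfig) shows how such a string maps to per-cluster settings:

#include <iostream>
#include <map>
#include <sstream>
#include <string>

// Stand-ins for OptimizationType / UserConfig from common.h.
enum class OptType { None = 0, Recompute = 1, RecomputeWithCompromise = 2 };
struct Config { OptType type; int requested_count; };

// Split "ClusterId:Type:Count" entries separated by commas, mirroring the format above.
std::map<std::string, Config> ParseConfig(const std::string& s) {
  std::map<std::string, Config> out;
  std::stringstream entries(s);
  std::string entry;
  while (std::getline(entries, entry, ',')) {
    std::stringstream fields(entry);
    std::string id, type_str, count_str;
    std::getline(fields, id, ':');
    std::getline(fields, type_str, ':');
    std::getline(fields, count_str, ':');
    out[id] = Config{static_cast<OptType>(std::stoi(type_str)), std::stoi(count_str)};
  }
  return out;
}

int main() {
  // "BiasGelu+" and "Cast+" are made-up subgraph pattern strings; -1 means apply to all occurrences.
  for (const auto& [id, cfg] : ParseConfig("BiasGelu+:1:-1,Cast+:2:1")) {
    std::cout << id << " -> type=" << static_cast<int>(cfg.type)
              << " count=" << cfg.requested_count << "\n";
  }
}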
+ +#pragma once + +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/logging/logging.h" +#include "core/common/inlined_containers_fwd.h" +#include "core/graph/basic_types.h" +#include "core/framework/data_types.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +// Uncomment for debugging Memory optimizer (MO). +// #define MO_NEED_LOG_DEBUG_INFO 1 + +#ifndef MO_LOG_DEBUG_INFO +#ifdef MO_NEED_LOG_DEBUG_INFO +#define MO_LOG_DEBUG_INFO(logger, message) LOGS(logger, WARNING) << message +#else +#define MO_LOG_DEBUG_INFO(logger, message) \ + ORT_UNUSED_PARAMETER(logger); \ + do { \ + } while (0) +#endif +#endif + +using NodeOutputPort = std::pair; +using ActivationUsedMap = InlinedHashMap>; + +/** + * @brief Type of memory reduction techniques. + */ +enum class OptimizationType { + None = 0, // Disabled. + Recompute = 1, + RecomputeWithCompromise = 2, + TypeMax = 3, +}; + +std::string OptimizationTypeToString(OptimizationType type); + +/** + * @brief Type of user config. + * type: type of memory reduction techniques. + * requested_count: the number of occurrences of a subgraph pattern for alleviation. -1 means apply all. + * One example: if a subgraph pattern is found 3 times, and requested_count is set 2, then the 1st and 2nd subgraph + * in topological order will be applied for alleviation. This is useful to avoid alleviating more memory than + * needed. + */ +struct UserConfig { + OptimizationType type; + int requested_count; +}; + +/** + * @brief Get total element count inn format of a symbolic string. + * + * @param node The node to get element count. + * @param output_index The output index of the node. + * @return std::string + */ +std::string GetTensorElemCountInSymbolicString(const Node* node, size_t output_index); + +int ParseIntValueFromString(std::string_view str); + +Status ParseConfigFromString(std::string_view memory_optimization_config, + InlinedHashMap& cluster_id_to_config_map); + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc new file mode 100644 index 0000000000000..60f62a9881ef4 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -0,0 +1,763 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" +#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" +#include "orttraining/core/optimizer/memory_optimizer/memory_insight.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +// Placeholder string for table row separator, which is used to be replaced by table row separator finally. +constexpr const char kTableRowSeparator[] = "TABLE_SEPARATOR_PLACEHOLDER"; +// Placeholder string for table border, which is used to be replaced by table border finally. +constexpr const char kTableBorder[] = "TABLE_BORDER_PLACEHOLDER"; + +// The max length of the first column in the table. +constexpr const int kFirstColumnWidth = 7; +// The max length of left part (e.g. title) in the second column. 
+constexpr const int kTitleWidthInSecondColumn = 15; + +/** + * @brief Prepare info including activation usage, node usage in fw and bw. + * + * @param graph Graph to iterate. + * @param boundary_op_order_in_topological_sort index of the boundary op between fw and bw. + * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. + * @param fw_op_output_arg_used_map Collected activation usage mapping. + * - key: node arg name + * - value: a pair of bool, representing whether the activation is used by forward nodes or by backward nodes. + * @param is_forward_nodes Collected node is forward pass op mapping. + */ +void GetForwardOutputUsageMap(const GraphViewer& graph_viewer, + const ptrdiff_t boundary_op_order_in_topological_sort, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + ActivationUsedMap& fw_op_output_arg_used_map, + InlinedHashMap& is_forward_nodes) { + ORT_ENFORCE(boundary_op_order_in_topological_sort >= 0); + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); + is_forward_nodes.clear(); + is_forward_nodes.reserve(node_ids.size()); + + auto is_forward_pass_operator = [](ptrdiff_t op_order_in_topological_sort, + ptrdiff_t boundary_op_order_in_topological_sort) -> bool { + return op_order_in_topological_sort <= boundary_op_order_in_topological_sort; + }; + + fw_op_output_arg_used_map.clear(); + fw_op_output_arg_used_map.reserve(node_ids.size()); + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph_viewer.GetNode(node_ids[i]); + if (p_node == nullptr /* skip removed nodes*/) { + continue; + } + + const Node& node = *p_node; + + bool is_forward_op = is_forward_pass_operator(static_cast(i), boundary_op_order_in_topological_sort); + if (!is_forward_op) { + is_forward_nodes[p_node] = false; + continue; + } + + is_forward_nodes[p_node] = true; + + for (auto& output_arg : node.OutputDefs()) { + if (!output_arg->Exists() || output_arg->Name().empty()) { + continue; + } + + bool used_in_fw = false; + bool used_in_bw = false; + for (auto& consumer_node : graph_viewer.GetConsumerNodes(output_arg->Name())) { + ORT_ENFORCE(consumer_node != nullptr, "Consumer node should not be null."); + auto it = node_index_to_its_order_in_topological_sort_map.find(consumer_node->Index()); + ORT_ENFORCE(it != + node_index_to_its_order_in_topological_sort_map.end(), + "Consumer node should be in topological order map."); + size_t consumer_node_index_in_topological_order = it->second; + if (is_forward_pass_operator(static_cast(consumer_node_index_in_topological_order), + boundary_op_order_in_topological_sort)) { + used_in_fw = true; + } else { + used_in_bw = true; + } + } + + ORT_ENFORCE(fw_op_output_arg_used_map.find(output_arg->Name()) == fw_op_output_arg_used_map.end(), + "Duplicated output arg found named: ", output_arg->Name()); + fw_op_output_arg_used_map.insert({{output_arg->Name(), std::make_pair(used_in_fw, used_in_bw)}}); + } + } +} + +/** + * @brief Find all stashed activations, e.g. activations used by forward operators and backward operators. + * + * @param graph_viewer Graph to iterate. + * @param boundary_op_order_in_topological_sort The order of the boundary op in the topological sort. + * @param fw_op_output_arg_used_map Activation usage mapping. + * @param candidate_output_args_map Candidate activations, which are consumed by both fw and bw ops. + * @param is_forward_nodes Whether a node is a forward node. + * @param logger Logger. 
+ * @return Status + */ + +Status GetStashedActivationCandidates(const GraphViewer& graph_viewer, + const ptrdiff_t boundary_op_order_in_topological_sort, + ActivationUsedMap& fw_op_output_arg_used_map, + InlinedHashMap>& + candidate_output_args_map, + InlinedHashMap& is_forward_nodes, + const logging::Logger& logger) { + if (boundary_op_order_in_topological_sort < 0) { + LOGS(logger, VERBOSE) << "No boundary op found. Skip memory optimization."; + return Status::OK(); + } + + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); + + InlinedHashMap node_index_to_its_order_in_topological_sort_map; + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph_viewer.GetNode(node_ids[i]); + if (p_node == nullptr) { /* skip removed nodes*/ + continue; + } + + node_index_to_its_order_in_topological_sort_map[p_node->Index()] = i; + } + + GetForwardOutputUsageMap(graph_viewer, boundary_op_order_in_topological_sort, + node_index_to_its_order_in_topological_sort_map, + fw_op_output_arg_used_map, + is_forward_nodes); + + for (auto& kv : fw_op_output_arg_used_map) { + // used by fw and bw, then it is a candidate. + if (kv.second.first && kv.second.second) { + const Node* n = graph_viewer.GetProducerNode(kv.first); + ORT_ENFORCE(n, "Activation should have a producer node"); + size_t k = 0; + for (k = 0; k < n->OutputDefs().size(); ++k) { + if (n->OutputDefs()[k]->Name().compare(kv.first) == 0) { + break; + } + } + + if (std::find(candidate_output_args_map[n].begin(), candidate_output_args_map[n].end(), k) != + candidate_output_args_map[n].end()) { + ORT_ENFORCE(false, "Duplicated candidate output found."); + } + + candidate_output_args_map[n].push_back(k); + LOGS(logger, VERBOSE) << "Find candidate output named [" << kv.first << "] of Node " << n->Name() << "(" + << n->OpType() << ")"; + } + } + + return Status::OK(); +} + +Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, + const ProbeLevel probe_level, + const logging::Logger& logger, + InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + ptrdiff_t& yield_op_order_in_topological_sort, + InlinedHashMap>& + candidate_output_args_map, + MemoryOptimizationPlanner& memory_opt_planner) { + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(); + + // Find boundary ops between forward and backward pass, currently, it's limited to YieldOp. + yield_op_order_in_topological_sort = -1; + for (size_t i = 0; i < node_ids.size(); ++i) { + const Node* p_node = graph_viewer.GetNode(node_ids[i]); + if (p_node == nullptr) { /* skip removed nodes*/ + continue; + } + + if (p_node->OpType() == "YieldOp") { + if (yield_op_order_in_topological_sort != -1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "There are multiple YieldOps in the graph, node: ", + p_node->Name(), " is the second one."); + } + yield_op_order_in_topological_sort = static_cast(i); + } + + node_index_to_its_order_in_topological_sort_map[p_node->Index()] = static_cast(i); + } + + ActivationUsedMap fw_op_output_arg_used_map; + + InlinedHashMap is_forward_nodes; + ORT_RETURN_IF_ERROR(GetStashedActivationCandidates(graph_viewer, + yield_op_order_in_topological_sort, + fw_op_output_arg_used_map, + candidate_output_args_map, + is_forward_nodes, + logger)); + + // The first pass - find the candidate subgraphs. 
+ for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { + const Node* p_node = graph_viewer.GetNode(node_ids[i]); + if (p_node == nullptr) { + continue; + } + + if (candidate_output_args_map.find(p_node) == candidate_output_args_map.end()) { + continue; + } + + bool can_compromise_stashed_activation = false; + std::unique_ptr recompute_plan = + CheckNodeForRecompute(*p_node, + probe_level, + fw_op_output_arg_used_map, + node_index_to_its_order_in_topological_sort_map, + candidate_output_args_map, + logger, false, + can_compromise_stashed_activation); + if (recompute_plan != nullptr) { + memory_opt_planner.AddNodeOptimizationPlan(p_node, std::move(recompute_plan)); + } + + if (can_compromise_stashed_activation) { + LOGS(logger, VERBOSE) << "Searching Node " << p_node->Name() << "(" << p_node->OpType() + << ") for compromised recompute"; + // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist + // during backward pass, then we can consider to recompute them. + std::unique_ptr recompute_with_compromise_plan = + CheckNodeForRecompute(*p_node, probe_level, fw_op_output_arg_used_map, + node_index_to_its_order_in_topological_sort_map, + candidate_output_args_map, + logger, true, + can_compromise_stashed_activation); + if (recompute_with_compromise_plan != nullptr) { + memory_opt_planner.AddNodeOptimizationPlan(p_node, std::move(recompute_with_compromise_plan)); + } + } + } + + return Status::OK(); +} + +void GetMemoryRecordsGroupedByNodeClusterId(const MemoryOptimizationPlanner& memory_opt_planner, + const NodeToClusterApplyContextMap& node_to_apply_contexts_map, + std::vector>& generated_records) { + // Group by node cluster id, generate memory record. + InlinedHashMap records; + const auto& node_to_optimization_plan_map = memory_opt_planner.GetNodeToOptimizationPlanMap(); + for (const auto& node_to_optimization_plan : node_to_optimization_plan_map) { + const auto& node = node_to_optimization_plan.first; + const auto& node_plans = node_to_optimization_plan.second; + const std::string node_cluster_id = memory_opt_planner.GenerateNodeClusterId(node); + + std::pair::iterator, bool> insert_result = + records.insert({node_cluster_id, MemoryRecord()}); + bool already_exist = !insert_result.second; + auto& record = insert_result.first->second; + record.freq++; + + // Collect more information for display. + for (auto& plan : node_plans) { + // Same node cluster id, plans might still have different reuse_buffer pattern, so we need to collect all of them. + if (plan->reuse_buffers.size() > 0) { + gsl::span output_indices = plan->GetActivationOutputIndices(); + for (auto output_index : output_indices) { + bool is_output_reusing_buffers = plan->reuse_buffers.find(output_index) != plan->reuse_buffers.end(); + if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise) { + if (is_output_reusing_buffers) { + record.output_port_reuse_recompute_with_compromise_count[output_index] += 1; + } + } else if (plan->GetOptimizationType() == OptimizationType::Recompute) { + if (is_output_reusing_buffers) { + record.output_port_reuse_recompute_count[output_index] += 1; + } + } + } + } + + // For other infos that are guaranteed identity by cluster id, just skip collecting. 
+ if (already_exist) { + continue; + } + + if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise) { + record.recompute_with_compromise_subgraph_str = + dynamic_cast(plan.get())->GetNodesInTopoOrderStr(); + } else if (plan->GetOptimizationType() == OptimizationType::Recompute) { + record.recompute_subgraph_str = dynamic_cast(plan.get())->GetNodesInTopoOrderStr(); + } + + gsl::span output_indices = plan->GetActivationOutputIndices(); + for (auto output_index : output_indices) { + const auto& output_def = node->OutputDefs()[output_index]; + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto()); + ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ", + DataTypeImpl::ToString(ml_data_type)); + const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); + ORT_ENFORCE(nullptr != tensor_type_base); + MLDataType elt_type = tensor_type_base->GetElementType(); + + const auto byte_count_per_element = elt_type->Size(); + if (plan->GetOptimizationType() == OptimizationType::RecomputeWithCompromise) { + record.compromise_recomputed_outputs.emplace_back( + output_index, + GetTensorElemCountInSymbolicString(node, output_index), + byte_count_per_element, + plan->GetSaveRatio()); + + } else if (plan->GetOptimizationType() == OptimizationType::Recompute) { + record.recomputed_outputs.emplace_back(output_index, + GetTensorElemCountInSymbolicString(node, output_index), + byte_count_per_element, + plan->GetSaveRatio()); + } + } + } + } + + // Sort by feq and then by record key, to make sure the output is deterministic. + InlinedVector> freq_to_record_key; + for (const auto& p : records) { + freq_to_record_key.push_back({p.second.freq, p.first}); + } + + std::sort(freq_to_record_key.begin(), freq_to_record_key.end(), [](auto& left, auto& right) { + if (left.first == right.first) { + return left.second.compare(right.second) > 0; + } + return left.first > right.first; + }); + + for (const auto& p : freq_to_record_key) { + const std::string record_key = p.second; + generated_records.push_back({record_key, records[record_key]}); + } + + // If apply context is provided, also update the actual applied count. + if (node_to_apply_contexts_map.size() > 0) { + InlinedHashMap node_cluster_id_to_record_map; + for (auto& p : generated_records) { + node_cluster_id_to_record_map[p.first] = &p.second; + } + + for (const auto& p : node_to_apply_contexts_map) { + const auto& node = p.first; + const auto& apply_context = p.second; + std::string node_cluster_id = memory_opt_planner.GenerateNodeClusterId(node); + if (apply_context->type == OptimizationType::Recompute) { + node_cluster_id_to_record_map[node_cluster_id]->actual_recompute_count += 1; + node_cluster_id_to_record_map[node_cluster_id]->request_recompute_count = apply_context->requested_count; + } else if (apply_context->type == OptimizationType::RecomputeWithCompromise) { + node_cluster_id_to_record_map[node_cluster_id]->actual_recompute_with_compromise_count += 1; + node_cluster_id_to_record_map[node_cluster_id]->request_recompute_with_compromise_count = + apply_context->requested_count; + } else { + ORT_THROW("Unsupported optimization type found."); + } + } + } +} + +// Function declare to make it compile. 
+void IterateNodeOptimizationPlan(const std::shared_ptr& plan, + const InlinedHashMap>>& + node_to_optimization_plans_map, + const InlinedVector>& + current_combination, + const logging::Logger& logger, + InlinedVector>>& + all_combinations); + +/* + * Iterate from a node, generate combinations for each optimization plan for it. + */ +void IterateNode(const Node* node, + const InlinedHashMap>>& + node_to_optimization_plans_map, + const InlinedVector>& + current_combination, + const logging::Logger& logger, + InlinedVector>>& + all_combinations) { + MO_LOG_DEBUG_INFO(logger, "Enter IterateNode: " + node->Name()); + if (node_to_optimization_plans_map.find(node) == node_to_optimization_plans_map.end()) { + MO_LOG_DEBUG_INFO(logger, "Exit IterateNode since reused node don't have optimization plans: " + node->Name()); + return; + } + + for (const std::shared_ptr& plan : node_to_optimization_plans_map.at(node)) { + if (std::find(current_combination.begin(), current_combination.end(), plan) != + current_combination.end()) { + continue; + } + InlinedVector> new_combination = current_combination; + new_combination.push_back(plan); + IterateNodeOptimizationPlan(plan, node_to_optimization_plans_map, new_combination, logger, all_combinations); + } + MO_LOG_DEBUG_INFO(logger, "Exit IterateNode: " + node->Name()); +} + +void ListAllCombinations(const InlinedVector>>>& + all_possible_node_optimization_plans, + int index, + const InlinedVector>& current_combination, + const logging::Logger& logger, + InlinedVector>>& + all_combinations) { + MO_LOG_DEBUG_INFO(logger, "Enter ListAllCombinations"); + if (index == static_cast(all_possible_node_optimization_plans.size())) { + if (std::find(all_combinations.begin(), all_combinations.end(), current_combination) == + all_combinations.end()) { + all_combinations.push_back(current_combination); + } + MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations after finding a new combination"); + return; + } + + for (const auto& plans : all_possible_node_optimization_plans[index]) { + for (const auto& plan : plans) { + InlinedVector> new_combination = current_combination; + new_combination.push_back(plan); + ListAllCombinations(all_possible_node_optimization_plans, index + 1, new_combination, logger, all_combinations); + } + } + + MO_LOG_DEBUG_INFO(logger, "Exit ListAllCombinations"); +} + +/** + * Iterate from a node optimization plan, if there is any buffer reuse in its node outputs, + * iterate all possible reuse buffer plan combinations. + */ +void IterateNodeOptimizationPlan(const std::shared_ptr& plan, + const InlinedHashMap>>& + node_to_optimization_plans_map, + const InlinedVector>& + current_combination, + const logging::Logger& logger, + InlinedVector>>& + all_combinations) { + MO_LOG_DEBUG_INFO(logger, "Enter IterateNodeOptimizationPlan: " + plan->GetClusterId()); + + // No reuse buffer, don't need to iterate further, we found a plan combination already. 
+ if (plan->reuse_buffers.size() == 0) { + MO_LOG_DEBUG_INFO(logger, "length of current_combination: " + + std::to_string(current_combination.size()) + ", " + plan->GetClusterId()); + all_combinations.push_back(current_combination); + MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan"); + return; + } + + InlinedVector>>> + all_possible_node_optimization_plans; + all_possible_node_optimization_plans.resize(plan->reuse_buffers.size()); + + size_t i = 0; + for (const auto& p : plan->reuse_buffers) { + MO_LOG_DEBUG_INFO(logger, ">>>reuse buffer: " + std::to_string(p.first)); + IterateNode(p.second.first, node_to_optimization_plans_map, {}, logger, all_possible_node_optimization_plans[i]); + ++i; + } + + ListAllCombinations(all_possible_node_optimization_plans, 0, current_combination, logger, all_combinations); + + MO_LOG_DEBUG_INFO(logger, "Exit IterateNodeOptimizationPlan: " + plan->GetClusterId()); +} + +// Return a deterministic string for multiple plans combinations. +std::string GetMultiplePlanClusterId(const InlinedVector>& plans) { + constexpr const int request_count = -1; // -1 means apply optimization to all appearances. + + std::ostringstream oss; + InlinedVector sorted_plans; + for (const auto& plan : plans) { + sorted_plans.push_back(plan->GetClusterId() + ":" + std::to_string(static_cast(plan->GetOptimizationType())) + + ":" + std::to_string(request_count)); + } + + std::sort(sorted_plans.begin(), sorted_plans.end()); + + for (const auto& plan : sorted_plans) { + if (oss.str().size() > 0) { + oss << ","; + } + oss << plan; + } + return oss.str(); +} + +void GetMemorySavingSymbolicString(const MemoryOptimizationPlanner& memory_opt_planner, + const logging::Logger& logger, + std::map>& + combination_cluster_ids_to_saved_symbolic_byte_map) { + // Group by "ClusterId:OptimizationType:RequestCount". 
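// ---------------------------------------------------------------------------
// Editor's note (not part of this change): the IterateNode / ListAllCombinations
// recursion above enumerates, for every reused buffer of a plan, all ways of also
// optimizing the producers of those buffers. A standalone, simplified sketch of
// that enumeration follows (Plan and the plan names are placeholders; the real
// code also de-duplicates combinations and carries shared_ptr plans):

#include <iostream>
#include <string>
#include <vector>

using Plan = std::string;

// Slot i holds the candidate plans for the i-th reused buffer; every emitted
// combination picks exactly one candidate per slot (a Cartesian product).
void ListCombinations(const std::vector<std::vector<Plan>>& slots, size_t index,
                      std::vector<Plan>& current, std::vector<std::vector<Plan>>& out) {
  if (index == slots.size()) {
    out.push_back(current);
    return;
  }
  for (const Plan& candidate : slots[index]) {
    current.push_back(candidate);
    ListCombinations(slots, index + 1, current, out);
    current.pop_back();
  }
}

int main() {
  // Two reused buffers, each with two hypothetical candidate plans.
  std::vector<std::vector<Plan>> slots = {{"RecomputeA", "CompromiseA"}, {"RecomputeB", "CompromiseB"}};
  std::vector<Plan> current;
  std::vector<std::vector<Plan>> all;
  ListCombinations(slots, 0, current, all);
  for (const auto& combo : all) {
    for (const auto& plan : combo) std::cout << plan << " ";
    std::cout << "\n";  // 4 combinations in total
  }
}
// ---------------------------------------------------------------------------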
+ InlinedVector>> all_combinations; + + combination_cluster_ids_to_saved_symbolic_byte_map.clear(); + const auto& node_to_optimization_plan_map = memory_opt_planner.GetNodeToOptimizationPlanMap(); + for (const auto& node_to_optimization_plan : node_to_optimization_plan_map) { + const auto& node = node_to_optimization_plan.first; + InlinedVector> current_combination; + MO_LOG_DEBUG_INFO(logger, ">>>Start looping node: " + node->Name()); + IterateNode(node, node_to_optimization_plan_map, current_combination, logger, all_combinations); + MO_LOG_DEBUG_INFO(logger, "<<Name()); + } + + for (const auto& combination : all_combinations) { + std::string combination_cluster_id = GetMultiplePlanClusterId(combination); + std::string symbolic_byte_count = ""; + for (const auto& plan : combination) { + if (symbolic_byte_count.size() > 0) { + symbolic_byte_count += " + "; + } + symbolic_byte_count += plan->GetMemorySavingSymbolicString(); + } + + if (symbolic_byte_count.size() > 0) { + symbolic_byte_count = "(" + symbolic_byte_count + ")"; + } + auto& p = combination_cluster_ids_to_saved_symbolic_byte_map[combination_cluster_id]; + const auto& original = p.first; + if (original.size() > 0) { + symbolic_byte_count = original + " + " + symbolic_byte_count; + } + + MO_LOG_DEBUG_INFO(logger, "combination_cluster_id: " + combination_cluster_id + + ", symbolic_byte_count: " + symbolic_byte_count); + + p.first = symbolic_byte_count; + p.second += 1; + } +} + +namespace { + +template +std::string ToFixedLengthString(T value, int length) { + std::ostringstream oss; + oss << std::setw(length) << std::left; + oss << value; + return oss.str(); +} + +void FormatRecomputeMemoryRecords(int option_index, + const MemoryRecord& record, + bool compromise_recompute, + InlinedVector& rows) { + const auto subgraph_str = compromise_recompute ? record.recompute_with_compromise_subgraph_str + : record.recompute_subgraph_str; + const auto opt_type = compromise_recompute ? OptimizationType::RecomputeWithCompromise + : OptimizationType::Recompute; + const auto request_count = compromise_recompute ? record.request_recompute_with_compromise_count + : record.request_recompute_count; + const auto actual_count = compromise_recompute ? record.actual_recompute_with_compromise_count + : record.actual_recompute_count; + + const std::string empty_first_col = "|" + ToFixedLengthString(std::string(), kFirstColumnWidth) + "|"; + + rows.push_back(empty_first_col); + rows.push_back(empty_first_col + + ToFixedLengthString(">>Option " + std::to_string(option_index), kTitleWidthInSecondColumn) + ": " + + OptimizationTypeToString(opt_type) + " subgraph " + subgraph_str); + + if (request_count) { + // Only show this if user requested it. + rows.push_back( + empty_first_col + + ToFixedLengthString(" Status", kTitleWidthInSecondColumn) + ": " + "Enabled, requested count=" + + std::to_string(request_count) + + ", actual applied count=" + std::to_string(actual_count)); + } else { + rows.push_back(empty_first_col + ToFixedLengthString(" Status", kTitleWidthInSecondColumn) + + ": Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=" + + subgraph_str + ":" + std::to_string(static_cast(opt_type)) + ":-1"); + } + + std::string activation_str = empty_first_col + " Stashed Activations: "; + rows.push_back(activation_str); + + const auto& reused_buffers = compromise_recompute ? 
record.output_port_reuse_recompute_with_compromise_count + : record.output_port_reuse_recompute_count; + if (reused_buffers.size() > 0) { + std::string reused_buffers_summary = empty_first_col + ToFixedLengthString(" - ReuseFreq", kTitleWidthInSecondColumn) + ": "; + for (const auto& p : reused_buffers) { + reused_buffers_summary += " Output " + std::to_string(p.first) + "(" + std::to_string(p.second) + "),"; + } + + rows.push_back(reused_buffers_summary); + } + + const auto activation_count = compromise_recompute ? record.compromise_recomputed_outputs.size() + : record.recomputed_outputs.size(); + for (size_t i = 0; i < activation_count; ++i) { + const MemoryRecord::OutputStat* stat; + if (compromise_recompute) { + stat = &record.compromise_recomputed_outputs[i]; + } else { + stat = &record.recomputed_outputs[i]; + } + + rows.push_back(empty_first_col + + ToFixedLengthString(" - Output " + std::to_string(stat->output_index), kTitleWidthInSecondColumn) + + ": [" + stat->output_shape_str + "], byte/elem: " + + std::to_string(stat->output_byte_count_per_element) + + ", " + std::to_string(static_cast(stat->saving_ratio * 100)) + + "% saved"); + } +} +} // namespace + +std::string SerializeMemoryRecords( + const std::vector>& records_grouped_by_node_cluster_id, + std::string_view user_config) { + InlinedVector rows; + rows.push_back(kTableBorder); + rows.push_back("|" + ToFixedLengthString("Freq", kFirstColumnWidth) + + "| Memory Optimization Opportunities (Clustered by node-level activation patterns)"); + rows.push_back(kTableRowSeparator); + + for (const auto& p : records_grouped_by_node_cluster_id) { + const auto& record = p.second; + rows.push_back("|" + ToFixedLengthString(record.freq, kFirstColumnWidth) + + "|For each row options are mutually exclusive, only one of them can be enabled."); + + int option_index = 1; + if (record.recomputed_outputs.size() > 0) { + FormatRecomputeMemoryRecords(option_index, record, false, rows); + option_index++; + } + + if (record.compromise_recomputed_outputs.size() > 0) { + FormatRecomputeMemoryRecords(option_index, record, true, rows); + option_index++; + } + rows.push_back(kTableRowSeparator); + } + + rows.push_back(kTableBorder); + + size_t max_length = 0; + for (auto& row : rows) { + max_length = std::max(max_length, row.length()); + } + + // Example is: + // static const std::string row_separator = + // "|_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _|\n"; + static const std::string kTableRowSeparatorStart = "|_ _ _ _|"; + size_t second_row_length = max_length - kTableRowSeparatorStart.length(); + if (second_row_length % 2 == 0) { + second_row_length += 2; + max_length += 2; + } else { + second_row_length += 3; // add 3 to make it even + max_length += 3; + } + std::string row_separator_full(second_row_length, ' '); + for (size_t i = 0; i < row_separator_full.size() - 1; ++i) { + if (i % 2 == 0) { + row_separator_full[i] = '_'; + } + } + row_separator_full[row_separator_full.size() - 1] = '|'; + row_separator_full = kTableRowSeparatorStart + row_separator_full; + + std::string table_border_full(max_length, '='); + std::ostringstream summary; + summary << std::endl; + summary << MakeString("MemoryInsight Summary - User config: ", (user_config.empty() ? 
"not provided" : user_config)) + << std::endl; + for (auto& row : rows) { + if (row == kTableRowSeparator) { + summary << row_separator_full << std::endl; + } else if (row == kTableBorder) { + summary << table_border_full << std::endl; + } else { + std::string filled_up = std::string(max_length - row.length(), ' '); + filled_up[filled_up.length() - 1] = '|'; + summary << row << filled_up << std::endl; + } + } + summary << "Note: use comma as a separator for enabling more than one subgraphs." << std::endl; + return summary.str(); +} + +std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer, + std::string_view memory_optimization_config, + std::string_view recompute_probe_level, + const logging::Logger& logger, + std::map>& + cluster_id_combinations_to_saved_symbolic_byte_map, + const OrtValueNameIdxMap* ortvalue_name_to_idx_map, + const SequentialExecutionPlan* p_seq_exec_plan) { + ProbeLevel probe_level = ProbeLevel::Advanced; + if (!recompute_probe_level.empty()) { + int probe_level_int = ParseIntValueFromString(recompute_probe_level); + ORT_ENFORCE(probe_level_int < static_cast(ProbeLevel::LevelMax) && + probe_level_int >= 0, + "Invalid probe level specified: ", recompute_probe_level); + probe_level = static_cast(probe_level); + } + + ptrdiff_t yield_op_order_in_topological_sort; + InlinedHashMap> candidate_output_args_map; + InlinedHashMap node_index_to_its_order_in_topological_sort_map; + + // The first pass - find the candidate subgraphs. + MemoryOptimizationPlanner memory_opt_planner; + ORT_ENFORCE(FindORTModuleMemoryOpportunity( + graph_viewer, + probe_level, + logger, + node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, + candidate_output_args_map, + memory_opt_planner) + .IsOK()); + + InlinedHashMap cluster_id_to_config_map; + // Finalize the plan according to user config, + // then create a ClusterApplyContext for each unique cluster (having the same node pattern) + + NodeToClusterApplyContextMap node_to_apply_context_map; + + if (!memory_optimization_config.empty()) { + ORT_ENFORCE(ParseConfigFromString(memory_optimization_config, cluster_id_to_config_map) + .IsOK()); + InlinedHashMap> node_to_opt_plan_map; + ORT_ENFORCE(memory_opt_planner.FinalizeNodePlansFromUserConfig(cluster_id_to_config_map, + node_to_opt_plan_map, + node_to_apply_context_map) + .IsOK()); + } + + if (ortvalue_name_to_idx_map != nullptr && p_seq_exec_plan != nullptr) { + ORT_ENFORCE(memory_opt_planner.UpdateNodePlansFromExecutionPlan(graph_viewer, + *ortvalue_name_to_idx_map, + *p_seq_exec_plan) + .IsOK()); + } + + std::vector> records; + GetMemoryRecordsGroupedByNodeClusterId(memory_opt_planner, node_to_apply_context_map, records); + + GetMemorySavingSymbolicString(memory_opt_planner, logger, cluster_id_combinations_to_saved_symbolic_byte_map); + + return SerializeMemoryRecords(records, memory_optimization_config); +} + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h new file mode 100644 index 0000000000000..c4267efdbea51 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.h @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include + +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" +#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +/** + * @brief A data structure to store memory optimization statistics for a specific node cluster id. + * + * We will collect statistics for each node cluster id. + * The node cluster id is generated from all possible optimization plans for a specific node, plus shape, data type, + * outputs, etc. For the nodes have the same node cluster id, they will have one single MemoryRecord, displayed + * as a row in the final memory optimization statistics table. + */ +class MemoryRecord { + public: + class OutputStat { + public: + OutputStat(size_t output_index, std::string_view output_shape, size_t output_byte_count_per_element, + float saving_ratio) + : output_index(output_index), + output_shape_str(output_shape), + output_byte_count_per_element(output_byte_count_per_element), + saving_ratio(saving_ratio) {} + + // output index, shape, byte count per element, saving ratio + size_t output_index; + std::string output_shape_str; + size_t output_byte_count_per_element; + float saving_ratio; + }; + + // Recompute Column + std::string recompute_subgraph_str; + InlinedVector recomputed_outputs; + int request_recompute_count = 0; + int actual_recompute_count = 0; + InlinedHashMap output_port_reuse_recompute_count; + + // RecomputeWithCompromise Column + std::string recompute_with_compromise_subgraph_str; + InlinedVector compromise_recomputed_outputs; + int request_recompute_with_compromise_count = 0; + int actual_recompute_with_compromise_count = 0; + InlinedHashMap output_port_reuse_recompute_with_compromise_count; + + // Frequency Column + int freq = 0; +}; + +/** + * @brief Iterate the graph and find all possible memory optimization opportunities for related nodes. + * + * @param graph_viewer The graph to iterate. + * @param probe_level The level to control allowed operations during recomputable subgraph detecting. + * @param logger Logger. + * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. + * @param yield_op_order_in_topological_sort The order of the boundary op in the topological sort. + * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and + * @param mem_opt_stats A store to maintain all found optimization plans for related nodes. + * @return Status + */ +Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, + const ProbeLevel probe_level, + const logging::Logger& logger, + InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + ptrdiff_t& yield_op_order_in_topological_sort, + InlinedHashMap>& candidate_output_args_map, + MemoryOptimizationPlanner& mem_opt_stats); + +/** + * @brief From the optimization plans, generate the memory optimization statistics table containing many MemoryRecords, + * each represents one node cluster id. + * + * @param memory_opt_planner The optimization planner to get optimization plans. + * @param node_to_apply_contexts_map The optimization applying information. + * @param generated_records Returns the generated memory optimization statistics table. + * (for example, how many are actually applied) to each MemoryRecord. 
+ */ +void GetMemoryRecordsGroupedByNodeClusterId(const MemoryOptimizationPlanner& memory_opt_planner, + const NodeToClusterApplyContextMap& + node_to_apply_contexts_map, + std::vector>& generated_records); + +/** + * @brief Serialize the memory optimization statistics table to a string. + * + * @param records_grouped_by_node_cluster_id The memory optimization statistics table. + * @param user_config The user configuration to the serialized string. + * @return std::string + */ +std::string SerializeMemoryRecords(const std::vector>& + records_grouped_by_node_cluster_id, + std::string_view user_config); + +/** + * @brief A public API exposed to retrieve the memory optimization statistics table, given a graph. + * + * If possible, session's allocation plans and execution plan will also be available to help the analysis. + * + * @param graph_viewer The graph to analyze. + * @param memory_optimization_config The user configuration to control the memory optimization. + * @param recompute_probe_level The level to control allowed operations during recomputable subgraph detecting. + * @param logger Logger. + * @param ortvalue_name_to_idx_map Optional. If provided, we will use it to map ort value name to index. + * @param p_seq_exec_plan Optional. If provided, we will use it to get allocation plans. + * @return std::string + */ +std::string GetSerializedORTModuleMemoryStat(const GraphViewer& graph_viewer, + std::string_view memory_optimization_config, + std::string_view recompute_probe_level, + const logging::Logger& logger, + // used as Python binding, so used std::map instead of InlinedHashMap + std::map>& + cluster_id_combinations_to_saved_symbolic_byte_map, + const OrtValueNameIdxMap* ortvalue_name_to_idx_map = nullptr, + const SequentialExecutionPlan* p_seq_exec_plan = nullptr); + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc new file mode 100644 index 0000000000000..7e042031f66a2 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.cc @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/graph/graph_utils.h" +#include "core/optimizer/utils.h" +#include "core/framework/ort_value_name_idx_map.h" +#include "core/framework/sequential_execution_plan.h" + +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +std::string NodeOptimizationPlanBase::GetMemorySavingSymbolicString() const { + std::string saving_str; + for (auto output_index : activation_output_indices_) { + // If the output is reusing other node's buffer, then no memory saving. 
+ if (reuse_buffers.find(output_index) != reuse_buffers.end()) { + continue; + } + + const auto& output_def = node->OutputDefs()[output_index]; + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(*output_def->TypeAsProto()); + ORT_ENFORCE(ml_data_type->IsTensorType(), "ml_type must be a tensor type, but it is ", + DataTypeImpl::ToString(ml_data_type)); + const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); + ORT_ENFORCE(nullptr != tensor_type_base); + MLDataType elt_type = tensor_type_base->GetElementType(); + const auto byte_count_per_element = elt_type->Size(); + if (!saving_str.empty()) { + saving_str += " + "; + } + saving_str = "(" + GetTensorElemCountInSymbolicString(node, output_index) + " * " + + std::to_string(byte_count_per_element) + " * " + + std::to_string(GetSaveRatio()) + ")"; + } + if (saving_str.empty()) { + return saving_str; + } + return "(" + saving_str + ")"; +} + +Status MemoryOptimizationPlanner::UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer, + const OrtValueNameIdxMap& ortvalue_name_to_idx_map, + const SequentialExecutionPlan& p_seq_exec_plan) { + InlinedHashMap idx_to_ortvalue_name_map; + for (const auto& entry : ortvalue_name_to_idx_map) { + idx_to_ortvalue_name_map[entry.second] = entry.first; + } + + for (const auto& node_to_optimization_plan : node_to_optimization_plans_map) { + const auto& node_plans = node_to_optimization_plan.second; + + for (auto& node_plan : node_plans) { + const std::string cluster_id = node_plan->GetClusterId(); + const Node* node = node_plan->node; + for (auto& output_index : node_plan->GetActivationOutputIndices()) { + const NodeArg* node_arg = node->OutputDefs()[output_index]; + const auto& ort_value_name = node_arg->Name(); + int ort_value_idx; + ORT_ENFORCE(ortvalue_name_to_idx_map.GetIdx(ort_value_name, ort_value_idx).IsOK()); + const auto& alloc_plan = p_seq_exec_plan.allocation_plan; + ORT_ENFORCE(ort_value_idx >= 0 && static_cast(ort_value_idx) < alloc_plan.size()); + const auto& per_alloc_plan = alloc_plan[ort_value_idx]; + if (per_alloc_plan.alloc_kind != AllocKind::kReuse) { + continue; + } + int reused_ort_value_idx = per_alloc_plan.reused_buffer; + const auto& reused_ort_value_name = idx_to_ortvalue_name_map.at(reused_ort_value_idx); + + const Node* p_node = graph_viewer.GetProducerNode(reused_ort_value_name); + if (p_node == nullptr) { + // This is a graph input. + continue; + } + + int src_op_output_index = optimizer_utils::IndexOfNodeOutput(*p_node, *node_arg); + node_plan->reuse_buffers[output_index] = std::make_pair(p_node, src_op_output_index); + } + } + } + + return Status::OK(); +} + +Status MemoryOptimizationPlanner::FinalizeNodePlansFromUserConfig( + const InlinedHashMap& cluster_id_to_user_configs, + InlinedHashMap>& node_to_opt_plan_map, + NodeToClusterApplyContextMap& node_to_apply_context_map) const { + if (cluster_id_to_user_configs.size() == 0) { + return Status::OK(); + } + + // Create a temporary map to store the apply context for each cluster pattern. + InlinedHashMap> cluster_id_to_apply_contexts_map; + + // We loop all nodes' optimization plans and find the match in user configs. + // If found in user configs, we finalize the plan and create/update the apply context for this node. + // If not found in user configs, we will not include the node in the returned result. 
+ for (const auto& node_to_optimization_plan : node_to_optimization_plans_map) { + const auto& node = node_to_optimization_plan.first; + const auto& node_plans = node_to_optimization_plan.second; + + for (auto& node_plan : node_plans) { + const std::string cluster_id = node_plan->GetClusterId(); + if (cluster_id_to_user_configs.find(cluster_id) == cluster_id_to_user_configs.end()) { + continue; + } + + const auto& user_config = cluster_id_to_user_configs.at(cluster_id); + if (node_plan->GetOptimizationType() == user_config.type) { + // First finalize the plan for this node. + node_to_opt_plan_map[node] = node_plan; + + // Create/Update the apply context for this node. + if (cluster_id_to_apply_contexts_map.find(cluster_id) == cluster_id_to_apply_contexts_map.end()) { + std::shared_ptr apply_context = std::make_shared(); + apply_context->requested_count = user_config.requested_count; + apply_context->type = user_config.type; + apply_context->total_frequency++; + cluster_id_to_apply_contexts_map.insert({cluster_id, apply_context}); + } + + node_to_apply_context_map[node] = cluster_id_to_apply_contexts_map.at(cluster_id); + + // If different plans for the same node have same cluster id, we only need to finalize the first one. + // The rest of them will be ignored. + break; + } + } + } + + return Status::OK(); +} + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h new file mode 100644 index 0000000000000..0e5e2967ec15a --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/optimization_planner.h @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "core/framework/ort_value_name_idx_map.h" +#include "core/framework/sequential_execution_plan.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +/** + * @brief Struct to store properties of a specific subgraph. + */ +class ClusterApplyContext { + public: + ClusterApplyContext() = default; + + OptimizationType type; + int requested_count{0}; + int total_frequency{0}; // The occurrence of this subgraph pattern in the graph. + + int applied_count{0}; // The number of times this subgraph pattern has been really applied in this transformer. + int skip_count{0}; // The number of times this subgraph instance has been skipped in reversed topological order. +}; + +/** + * @brief Base class for a concrete optimization plan. + * + */ +class NodeOptimizationPlanBase { + public: + NodeOptimizationPlanBase(const Node* node, + gsl::span activation_output_indices, + float save_ratio) + : node(node), + activation_output_indices_(activation_output_indices.begin(), activation_output_indices.end()), + save_ratio_(save_ratio) { + } + + virtual ~NodeOptimizationPlanBase() = default; + + virtual OptimizationType GetOptimizationType() const = 0; + + /** + * Get the cluster id for this optimization plan. + * This cluster id is used to enable the optimization as a unique identity, for example, for recompute it is a + * subgraph string representation. + * @return std::string + */ + virtual std::string GetClusterId() const = 0; + + /** + * Get a string used to generate node cluster id for this optimization plan. 
+ * Node cluster id is on Node level, each node can have multiple optimization plans, each plan generates its + * normalization string. Once combined we get Node cluster id. This id is used to categorize nodes into different + * groups, showing them as one row in memory optimization opportunity table. + * @return std::string + */ + virtual std::string NormalizeForNodeClusterId() const = 0; + + /** + * Return all output indices that are used as activation buffers. + */ + gsl::span GetActivationOutputIndices() const { return activation_output_indices_; } + + /** + * Return the saving ratio for this optimization plan. + */ + float GetSaveRatio() const { return save_ratio_; } + + /** + * Get a symbolic string to represent the memory saving for this optimization plan. + */ + std::string GetMemorySavingSymbolicString() const; + + const Node* node; + // A map: output index reusing other node's output (other_node, output index) + InlinedHashMap reuse_buffers; + + private: + InlinedVector activation_output_indices_; + float save_ratio_ = 1.0f; +}; + +using NodeToClusterApplyContextMap = InlinedHashMap>; + +class MemoryOptimizationPlanner { + public: + void AddNodeOptimizationPlan(const Node* node, + std::shared_ptr plan) { + if (node_to_optimization_plans_map.find(node) == node_to_optimization_plans_map.end()) { + node_to_optimization_plans_map.insert({node, {}}); + } + + node_to_optimization_plans_map[node].emplace_back(plan); + } + + Status UpdateNodePlansFromExecutionPlan(const GraphViewer& graph_viewer, + const OrtValueNameIdxMap& ortvalue_name_to_idx_map, + const SequentialExecutionPlan& p_seq_exec_plan); + + Status FinalizeNodePlansFromUserConfig( + const InlinedHashMap& cluster_id_to_user_configs, + InlinedHashMap>& node_to_opt_plan_map, + NodeToClusterApplyContextMap& node_to_apply_context_map) const; + + std::string GenerateNodeClusterId(const Node* node) const { + ORT_ENFORCE(node_to_optimization_plans_map.find(node) != node_to_optimization_plans_map.end(), + "Node not found in node_to_optimization_plans_map."); + std::ostringstream oss; + const auto& node_plans = node_to_optimization_plans_map.at(node); + for (auto& plan : node_plans) { + oss << plan->NormalizeForNodeClusterId(); + } + + return oss.str(); + } + + const InlinedHashMap>>& + GetNodeToOptimizationPlanMap() const { + return node_to_optimization_plans_map; + } + + private: + InlinedHashMap>> node_to_optimization_plans_map; +}; + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc new file mode 100644 index 0000000000000..0782cbdae2eec --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -0,0 +1,405 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
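+// Recompute analysis: starting from each node that produces stashed activations, trace the graph bottom-up through
+// an allow-list of ops (ProbeLevel Basic/Advanced) to collect candidate recompute subgraphs.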
+ +#include +#include +#include +#include +#include +#include + +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "orttraining/core/optimizer/memory_optimizer/recompute_analysis.h" +#include "core/framework/data_types.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +namespace { + +constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; + +static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { + const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); + MLDataType ml_data_type = DataTypeImpl::TypeFromProto(type_proto); + const TensorTypeBase* tensor_type_base = ml_data_type->AsTensorType(); + ORT_ENFORCE(nullptr != tensor_type_base); + MLDataType elt_type = tensor_type_base->GetElementType(); + return elt_type->Size(); +} + +// TODO(pengwa): extent this function to be more general. +float InputOutputSizeRatio(const Node* node) { + if (node->OpType().compare("Cast") == 0) { + const NodeArg* input = node->InputDefs()[0]; + const NodeArg* output = node->OutputDefs()[0]; + if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING || + output->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + return 1.0f; + } + const auto& ptype1 = input->Type(); + const auto& ptype2 = output->Type(); + float ratio = static_cast(GetElementSize(ptype1)) / static_cast(GetElementSize(ptype2)); + return ratio; + } + + return 1.0f; +} + +/** + * @brief Used to define per-op recompute config. + * + */ +struct AllowedRecomputeNodeConfig { + InlinedVector input_arg_indices; // input index to iterate further (bottom up) +}; + +// The op types that are supported predefined. + +const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { + static InlinedHashMap> recomputable_op_table_map; + if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) { + return recomputable_op_table_map.at(probe_op_level); + } + + recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); + auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level); + if (probe_op_level >= static_cast(ProbeLevel::Basic)) { + recomputable_op_table.insert({ + // Binary elementwise + {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, + {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, + + // Data layout + /// The shape input is trivial whether it exists or not in backward. 
+ {"Reshape", AllowedRecomputeNodeConfig{{0}}}, + {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, + {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, + + // Unary elementwise + /// The ratio and mode input are trivial whether they exist or not in backward + {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}}, + /// The axis input is trivial whether it exists or not in backward + {"CumSum", AllowedRecomputeNodeConfig{{0}}}, + {"Dropout", AllowedRecomputeNodeConfig{{0}}}, + {"Gelu", AllowedRecomputeNodeConfig{{0}}}, + {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, + + // Ternary elementwise + {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, + + // Data copy + {"Tile", AllowedRecomputeNodeConfig{{0}}}, + {"Cast", AllowedRecomputeNodeConfig{{0}}}, + }); + } + + if (probe_op_level >= static_cast(ProbeLevel::Advanced)) { + recomputable_op_table.insert({ + {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}}, + {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}}, + {"Softmax", AllowedRecomputeNodeConfig{{0}}}, + {"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}}, + {"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}}, + }); + } + + return recomputable_op_table; +} + +/** + * @brief Check whether a node is a recomputable node at given probe level. + */ +bool IsRecomputable(const Node& node, ProbeLevel probe_level) { + const auto& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + return op_table.find(node.OpType()) != op_table.end(); +} + +/** + * @brief Find recomputable subgraphs (has at least one nodes, at most MAXIMUM_RECOMPUTE_NODE_COUNT nodes). + * + * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs. + * @param node_output_index_candidates Candidate output indices of "node", which are consumed by both fw and bw ops. + * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping. + * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort. + * Used to re-order the collected subgraph nodes. + * @param nodes_in_topological_order Collected vector of nodes of found subgraph, in the order of the topological + * sorted. + * @param logger Logger. + * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a + * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph to reduce the + * size of stashed activation. + * @param can_compromise_stashed_activation A bool return value, to indicate there is opportunaties for finding a + * compromised subgraph. + * @param save_ratio The ratio of memory saving if we can find a recomputable subgraph. 
+ * @return Status + */ +Status SelectRecomputeSubgraph(const Node& entry_node, + const ProbeLevel probe_level, + const InlinedVector& node_output_index_candidates, + const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const logging::Logger& logger, + InlinedVector& nodes, + bool compromise_stashed_activation, + bool& can_compromise_stashed_activation, + float& save_ratio) { + const auto& recomputable_op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + + can_compromise_stashed_activation = false; + + LOGS(logger, VERBOSE) << "Enter SelectRecomputeSubgraph for Node " << entry_node.Name() << "(" + << entry_node.OpType() << ")"; + nodes.clear(); + + std::deque q; + for (auto output_index : node_output_index_candidates) { + q.push_back(NodeOutputPort(&entry_node, output_index)); + } + + bool early_stop = false; + std::set visited_output_arg_set; + std::set visited_node_set; + + // For the initial activations in queue, they are stashed ones, so we do differently when scanning the queue for them. + bool is_first_queue_scan = true; + while (nodes.size() < MAXIMUM_RECOMPUTE_NODE_COUNT && !q.empty() && !early_stop) { + // Loop all candidate NodeOutputPort, and find the next layer of input nodes. + size_t current_queue_size = q.size(); + for (size_t i = 0; i < current_queue_size; ++i) { + NodeOutputPort p = q.front(); + q.pop_front(); + const Node* curr_node = p.first; + + // Skip if the node output is already visited. + if (std::find(visited_output_arg_set.begin(), visited_output_arg_set.end(), p) != + visited_output_arg_set.end()) { + continue; + } + + visited_output_arg_set.insert({p}); + + // If the node is already visited by from its other output index, skip it. + if (visited_node_set.find(curr_node) != visited_node_set.end()) { + continue; + } + + visited_node_set.insert(curr_node); + + // Bottom-up search rules. + // If current op is entry output node (that generates stashed activations): + // 1. If the op is not in recomputable_op_table, skip it. + // Otherwise: + // If current op is in allowed list, check its input args, and append the producers' NodeOutputPorts to next_q. + // If current op is NOT in allowed list: + // 1). the output does not exist in backward, we cannot find a good solution for so, the search terminates. + // 2). the output is used in backward, we don't need to trace back further, so continue searching. + auto op_recompute_config_it = recomputable_op_table.find(curr_node->OpType()); + auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name(); + if (is_first_queue_scan) { + // We handle the entry node outputs differently because, we don't want this case falls into and succeed one of + // the checks in the other branch + // 1. "op is not in recompute op list, but its output is used in backward" + // 2. 
"op is in recompute op list, but its output is used in backward" + // (either of the above checks is true for entry node outputs) + if (op_recompute_config_it == recomputable_op_table.end()) { + early_stop = true; + LOGS(logger, VERBOSE) << "Entry Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** " + << "in recompute op list, search terminates."; + break; + } + } else { + if (op_recompute_config_it == recomputable_op_table.end()) { + if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " + << "recompute op list, but its output [" << cur_output_arg_name << "] is used in " + << "backward, we don't need trace bottom-up further. Entry node: " + << entry_node.Name() << "(" << entry_node.OpType() << ")"; + continue; + } else { + early_stop = true; + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is **NOT** in " + << "recompute op list, and its output [" << cur_output_arg_name + << "] does not exist in backward, search terminates. Entry node: " + << entry_node.Name() << "(" << entry_node.OpType() << ")"; + break; + } + } + + if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") " + << "is in recompute op list, while its output [" << cur_output_arg_name + << "] is used in backward, we don't need trace bottom-up further. Entry node: " + << entry_node.Name() << "(" << entry_node.OpType() << ")"; + continue; + } + } + + // Append node to the selected graph. + if (std::find(nodes.begin(), nodes.end(), curr_node) == nodes.end()) { + nodes.push_back(curr_node); + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() + << ") is added in selected subgraph "; + } + + // This check is not matured now, subject to change. + float ratio = InputOutputSizeRatio(curr_node); + float saving_ratio = 1.0f - ratio; + float is_current_node_compromisable = (ratio < 1.f); + can_compromise_stashed_activation = can_compromise_stashed_activation || is_current_node_compromisable; + if (is_current_node_compromisable) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() + << ") has input/output size " << ratio << " < 1.f, can compromise stashed activation"; + } + + if (is_current_node_compromisable && compromise_stashed_activation) { + LOGS(logger, VERBOSE) << "Node " << curr_node->Name() << "(" << curr_node->OpType() << ") is in " + << "recompute op list, and its output [" << cur_output_arg_name + << "] does not exist in backward, while it meets compromised check, we don't need trace " + << "bottom-up further."; + save_ratio = saving_ratio; + continue; + } + + // Iterate all input nodes according to allowed input arg index of the entry node. 
+ const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices; + for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) { + const Node::EdgeEnd& input_edge = *it; + const auto& parent_node = input_edge.GetNode(); + const auto parent_node_output_index = input_edge.GetSrcArgIndex(); + const auto current_node_input_index = input_edge.GetDstArgIndex(); + if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != + input_arg_indices.end()) { + NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); + + LOGS(logger, VERBOSE) << "Node " << parent_node.Name() << "(" << parent_node.OpType() << ")'s " + << parent_node_output_index + << "th output [" << parent_node.OutputDefs()[parent_node_output_index]->Name() + << "] is added in recompute search list "; + + q.push_back(next_p); + } + } + } + // After handling all entry node outputs, we set the flag to false. + is_first_queue_scan = false; + } + + // If input args are not found in bw, but op count exceed MAXIMUM_RECOMPUTE_NODE_COUNT, skip recompute. + if (!q.empty() || early_stop) { + LOGS(logger, VERBOSE) << "Fail to find a solution for recompute: current node count is " << nodes.size() + << ", queue size: " << q.size() << ", early stop: " << early_stop; + nodes.clear(); + } else { + // Re-order the nodes in topological order. + std::sort(nodes.begin(), nodes.end(), + [&node_index_to_its_order_in_topological_sort_map](const Node*& lhs, const Node*& rhs) { + return node_index_to_its_order_in_topological_sort_map.at(lhs->Index()) < + node_index_to_its_order_in_topological_sort_map.at(rhs->Index()); + }); + } + return Status::OK(); +} + +/** + * @brief Convert the recompute subgraph to its string representation. + * + * @param nodes_in_topological_order The subgraph nodes in topological order. + * @param subgraph_string_representation Returns subgraph string representation. + * @param log_info Returns log info for users. + */ +void NodesInTopoOrderToString(gsl::span nodes_in_topological_order, + std::string& subgraph_string_representation, + std::string& log_info) { + std::ostringstream oss; + std::ostringstream subgraph_string_representation_oss; + size_t node_count = nodes_in_topological_order.size(); + for (size_t i = 0; i < node_count; ++i) { + if (i < node_count - 1) { // Ignore the last node. 
+ oss << "(name:" << nodes_in_topological_order[i]->Name() << ", type:" << nodes_in_topological_order[i]->OpType() + << "),"; + } + + subgraph_string_representation_oss << nodes_in_topological_order[i]->OpType() << "+"; + } + + subgraph_string_representation = subgraph_string_representation_oss.str(); + log_info = oss.str(); + if (log_info.size() > 0) { + log_info = " with its precedent nodes: " + log_info; + } +} + +} // namespace + +std::unique_ptr CheckNodeForRecompute(const Node& node, + const ProbeLevel probe_level, + const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const InlinedHashMap>& + candidate_output_args_map, + const logging::Logger& logger, + bool compromise_stashed_activation, + bool& can_compromise_stashed_activation) { + if (!IsRecomputable(node, probe_level)) { + return nullptr; + } + + InlinedVector nodes_in_topological_order; + float save_ratio = 1.f; + ORT_ENFORCE(SelectRecomputeSubgraph(node, + probe_level, + candidate_output_args_map.at(&node), + fw_op_output_arg_used_map, + node_index_to_its_order_in_topological_sort_map, + logger, + nodes_in_topological_order, + compromise_stashed_activation, + can_compromise_stashed_activation, + save_ratio) + .IsOK()); + if (nodes_in_topological_order.size() == 0) { + return nullptr; + } + + std::string subgraph_str_representation, log_info; + NodesInTopoOrderToString(nodes_in_topological_order, subgraph_str_representation, log_info); + + LOGS(logger, VERBOSE) << "Node " << node.Name() << "(" << node.OpType() << ") can be recomputed" << log_info; + + return std::make_unique(&node, candidate_output_args_map.at(&node), + nodes_in_topological_order, + compromise_stashed_activation, + save_ratio); +} + +std::string NodeRecomputePlan::GetClusterId() const { + std::ostringstream oss; + oss << GetNodesInTopoOrderStr(); + return oss.str(); +} + +std::string NodeRecomputePlan::NormalizeForNodeClusterId() const { + std::ostringstream oss; + oss << "recompute:" << node->OpType() << "-" + << compromise_recompute_ << "-"; + for (auto& output_index : GetActivationOutputIndices()) { + oss << output_index << ":" << GetTensorElemCountInSymbolicString(node, output_index); + oss << ":" << node->OutputDefs()[output_index]->TypeAsProto()->tensor_type().elem_type() << "-"; + } + + oss << GetNodesInTopoOrderStr(); + return oss.str(); +} + +std::string NodeRecomputePlan::GetNodesInTopoOrderStr() const { + std::string subgraph_str_representation, log_info; + NodesInTopoOrderToString(nodes_in_topological_order_, subgraph_str_representation, log_info); + return subgraph_str_representation; +} + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h new file mode 100644 index 0000000000000..9211e5044cd86 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.h @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "orttraining/core/optimizer/memory_optimizer/common.h" +#include "orttraining/core/optimizer/memory_optimizer/optimization_planner.h" + +namespace onnxruntime::optimizer::memory_optimizer { + +/** + * @brief Level to control allowed operations during subgraph detecting. + * Level 0: only allow cheap-to-compute operations. 
+ * Level 1: allow more expensive operations.
+ */
+enum class ProbeLevel {
+  Basic = 0,
+  Advanced = 1,
+  LevelMax = 2,
+};
+
+/**
+ * @brief A child class used for Recompute/RecomputeWithCompromise optimization plans.
+ *
+ * For each node generating stashed activations, a recompute plan can be created for it.
+ */
+class NodeRecomputePlan : public NodeOptimizationPlanBase {
+ public:
+  NodeRecomputePlan(const Node* node,
+                    const InlinedVector& activation_output_indices,
+                    const InlinedVector& nodes_in_topological_order,
+                    bool compromise_recompute = false,
+                    float save_ratio = 1.0f) : NodeOptimizationPlanBase(node, activation_output_indices, save_ratio) {
+    compromise_recompute_ = compromise_recompute;
+    // Note: recompute is node-level; every node arg of the node shares the same optimization type.
+    nodes_in_topological_order_ = nodes_in_topological_order;
+  }
+
+  const InlinedVector& GetNodesInTopoOrder() const { return nodes_in_topological_order_; }
+
+  bool IsCompromiseRecompute() const { return compromise_recompute_; }
+
+  OptimizationType GetOptimizationType() const override {
+    return compromise_recompute_ ? OptimizationType::RecomputeWithCompromise
+                                 : OptimizationType::Recompute;
+  }
+
+  /**
+   * @brief Get the cluster id for this recompute plan.
+   * The cluster id is used to identify a unique subgraph.
+   * Users can pass such a cluster id to enable this specific memory optimization for the subgraph.
+   */
+  std::string GetClusterId() const override;
+
+  /**
+   * @brief Get the serialized string for this recompute plan, used to create the Node-level cluster id.
+   * A Node can have multiple optimization plans; each plan generates its own normalization string.
+   * Once combined, we get the Node cluster id.
+   *
+   * The Node cluster id is used to categorize nodes into different groups, showing them as one row in the memory
+   * optimization opportunity table.
+   */
+  std::string NormalizeForNodeClusterId() const override;
+
+  std::string GetNodesInTopoOrderStr() const;
+
+ private:
+  bool compromise_recompute_;
+  InlinedVector nodes_in_topological_order_;
+};
+
+/**
+ * @brief For a node producing stashed activations, check whether a recomputable subgraph can be found or not.
+ *
+ * @param node The entry node to start the subgraph matching (bottom-up), usually the last node of found subgraphs.
+ * @param probe_level The level to control allowed operations during subgraph detection.
+ * @param fw_op_output_arg_used_map The activation usage (in fw and bw) mapping.
+ * @param node_index_to_its_order_in_topological_sort_map The mapping of node index to its order in topological sort.
+ * Used to re-order the collected subgraph nodes.
+ * @param candidate_output_args_map A map from node to its candidate activations, which are consumed by both fw and
+ * bw ops.
+ * @param logger Logger.
+ * @param compromise_stashed_activation Whether to compromise stashed activation, e.g. if we cannot find a
+ * recomputable subgraph to save a stashed activation, we can compromise to find a recomputable subgraph that reduces
+ * the size of the stashed activation.
+ * @param can_compromise_stashed_activation A bool return value indicating whether there is an opportunity to find a
+ * compromised subgraph.
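+ * @return A unique_ptr to a NodeRecomputePlan covering the found subgraph, or nullptr when no recomputable subgraph
+ * is found.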
+ */ +std::unique_ptr CheckNodeForRecompute(const Node& node, + const ProbeLevel probe_level, + const ActivationUsedMap& fw_op_output_arg_used_map, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const InlinedHashMap>& + candidate_output_args_map, + const logging::Logger& logger, + bool compromise_stashed_activation, + bool& can_compromise_stashed_activation); + +} // namespace onnxruntime::optimizer::memory_optimizer diff --git a/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc b/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc index dcb3abf2474d3..e719a21118028 100644 --- a/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc +++ b/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc @@ -254,7 +254,9 @@ Status ScaledSumFusion::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve handled_scaled_sum_count += 1; } - LOGS(logger, INFO) << "Total fused ScaledSum node count: " << handled_scaled_sum_count; + if (handled_scaled_sum_count > 0) { + LOGS(logger, INFO) << "Total fused ScaledSum node count: " << handled_scaled_sum_count; + } return Status::OK(); } diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index bb1cb4bbd32f7..a5f46d88e4e8b 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -433,7 +433,20 @@ void addObjectMethodsForTraining(py::module& m) { if (!status.IsOK()) { throw std::runtime_error("Error in backward pass execution: " + status.ErrorMessage()); } - }); + }) + .def("get_serialized_ortmodule_memory_stat", // for memory optimization + [](TrainingAgent* agent, // agent + const std::string& memory_optimization_config, // user config string + const std::string& recompute_probe_level // user config string for probe level + ) -> std::tuple>> { + std::map> cluster_id_combinations_to_saved_symbolic_byte_map; + std::string opportunity_table = + agent->GetSerializedORTModuleMemoryStat(memory_optimization_config, + recompute_probe_level, + cluster_id_combinations_to_saved_symbolic_byte_map); + return std::tuple>>( + opportunity_table, cluster_id_combinations_to_saved_symbolic_byte_map); + }); py::enum_(m, "PropagateCastOpsStrategy", py::module_local(), py::arithmetic{}) .value("NONE", GraphTransformerConfiguration::PropagateCastOpsConfiguration::Strategy::None) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 549614de496a6..a57105545e114 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -53,6 +53,8 @@ def generate_artifacts( 3. Checkpoint (directory): Contains the model parameters. 4. Optimizer model (onnx.ModelProto): Model containing the optimizer graph. + All generated ModelProtos will use the same opsets defined by *model*. + Args: model: The base model to be used for gradient graph generation. 
requires_grad: List of names of model parameters that require gradient computation @@ -207,11 +209,17 @@ def _export_to_ort_format(model_path, output_dir, extra_options): logging.info("Optimizer enum provided: %s", optimizer.name) + opset_version = None + for domain in model.opset_import: + if domain.domain == "" or domain.domain == "ai.onnx": + opset_version = domain.version + break + optim_model = None optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD} optim_block = optim_blocks[optimizer]() - with onnxblock.empty_base(): + with onnxblock.empty_base(opset_version=opset_version): _ = optim_block(model_params) optim_model = optim_block.to_model_proto() diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py index dc9e0c18eac15..3213a8831ae22 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py @@ -5,6 +5,8 @@ import os +import torch + from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401 from ._slice_scel import slice_scel, slice_scel_backward # noqa: F401 @@ -17,7 +19,12 @@ "slice_scel_backward", ] -if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1: +if ( + "ORTMODULE_USE_FLASH_ATTENTION" in os.environ + and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1 + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 8 +): from ._flash_attn import flash_attn_backward, flash_attn_forward # noqa: F401 _all_kernels.extend(["flash_attn_forward", "flash_attn_backward"]) diff --git a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py index 533fea5a0a721..7a89aadee9950 100644 --- a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py +++ b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from typing import Tuple + import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.capi._pybind_state import TrainingAgent as C_TrainingAgent @@ -161,3 +163,13 @@ def run_backward(self, feeds, fetches, state): :param state: State of the graph that is used for executing partial graph runs. """ self._training_agent.run_backward(feeds, fetches, state) + + def get_serialized_ortmodule_memory_stat( + self, memory_optimization_config: str, recompute_probe_level: str + ) -> Tuple[str, dict]: + """ + Get serialized memory stats for OrtModule. 
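+        Returns a tuple of (serialized memory optimization opportunity table, dict mapping each cluster id
+        combination to [symbolic memory saving expression, frequency of the subgraph pattern in the graph]).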
+ """ + return self._training_agent.get_serialized_ortmodule_memory_stat( + memory_optimization_config, recompute_probe_level + ) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 5eb1d9f382380..26993dec17ccf 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,7 +19,7 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch_dtype +from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils @@ -91,7 +91,8 @@ def __init__( self._first_skip_check_warning = True # Inspector for runtime information, for example input data, memory usage, etc. - self._runtime_inspector = RuntimeInspector(self._logger) + self._runtime_inspector = RuntimeInspector(self._logger, self._original_module) + self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step) # Tracker for ORTModule model export, session creation overhead. self.time_tracker = _logger.TimeTracker() @@ -242,12 +243,6 @@ def _get_session_config(self): # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. session_options.log_severity_level = int(self._debug_options.logging.log_level) - session_options.add_session_config_entry( - "optimization.enable_memory_optimizer", self._runtime_options.memory_optimizer_config - ) - session_options.add_session_config_entry( - "optimization.enable_memory_probe_recompute_level", self._runtime_options.probe_level - ) # Disable weight prepacking session_options.add_session_config_entry("session.disable_prepacking", "1") @@ -318,7 +313,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu """ # VERBOSE -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) - # INFO -> FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) + # DEVINFO -> FULL export verbose log + FULL torch other logs from stdout and stderr (C++ backend) + # INFO -> [Rank 0] FULL export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) # WARNING/ERROR -> [Rank 0] NO export verbose log + FILTERED torch other logs from stdout and stderr (C++ backend) # Be noted: rank 0 log only is controlled by logger configured in _logger.py torch_exporter_verbose_log = self._debug_options.logging.log_level <= LogLevel.INFO @@ -565,7 +561,6 @@ def _enable_conditional_optimizations( enable sparsity-based optimization. """ - # Enable data sparsity inspection if sparse optimizer is ON or user wants to print input density. 
if self._runtime_options.enable_sparse_optimizer or self._runtime_options.print_input_density: self._runtime_inspector.enable_input_inspector( @@ -612,9 +607,6 @@ def _enable_conditional_optimizations( if not self._runtime_options.print_input_density: self._runtime_inspector.disable_input_inspector() - if self._runtime_options.print_memory_stat: - self._runtime_inspector.enable_memory_inspector(self._original_module) - def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): from ._zero_stage3_compatibility import ( STAGE3_PULL_WEIGHT_TRIGGER_NAME, @@ -634,105 +626,141 @@ def _log_feature_stats(self): if get_rank() != 0: return - feature_map: List[Tuple[str, bool, str]] = [ - ("ATen Executor", True, "Dispatch ATen operators to ORT's ATen executor"), - ( + if self._runtime_inspector.memory_ob.is_enabled() and self._debug_options.log_level <= LogLevel.DEVINFO: + self._logger.info(self._runtime_inspector.memory_ob.memory_optimization_opportunity_table_str) + + tbl = PTable() + + def _add_record(tbl, columns): + return tbl.add_row([columns[0], ":", "ON" if columns[1] else "OFF", ":", columns[2]]) + + notes = [] + + _add_record(tbl, ["ATen Executor", True, "Dispatch ATen operators to ORT's ATen executor"]) + _add_record( + tbl, + [ "Cast Propagation", self._runtime_options.propagate_cast_ops_level > 0, f"Level {self._runtime_options.propagate_cast_ops_level} enabled", - ), - ( + ], + ) + _add_record( + tbl, + [ "Custom Function", self._runtime_options.enable_custom_autograd_function, "Support custom torch.autograd.Function export and execution", - ), - ( - "Memory Optimizer", - len(self._runtime_options.memory_optimizer_config) > 0, - "Enable with env ORTMODULE_MEMORY_OPT_CONFIG=", - ), - ] + ], + ) - # Add compute optimizer - feature_map.extend( + output_memory_optimization_details = self._debug_options.log_level <= LogLevel.INFO + mem_row = _add_record( + tbl, [ + "Memory Optimizer", + len(self._runtime_options.memory_optimizer_config) > 0, ( - "Compute Optimizer", - self._runtime_options.enable_compute_optimizer, - "Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0", - ), - ( - " -FLOPReduction", - self._runtime_options.enable_compute_optimizer, - "Reduce FLOPs by upstreaming shrinking-sized ops", + f"User config: {self._runtime_options.memory_optimizer_config}, probe level: {self._runtime_options.probe_level}" + if len(self._runtime_options.memory_optimizer_config) > 0 + else "Enable with env ORTMODULE_MEMORY_OPT_CONFIG=" ), - ] + ], + ) + + if self._runtime_inspector.memory_ob.is_enabled() and output_memory_optimization_details: + mem_notes, mem_tbl = self._runtime_inspector.memory_ob.display_memory_optimization_plans( + self._runtime_options.memory_optimizer_config + ) + if mem_tbl is not None: + mem_row.append_annotation_table(mem_tbl) + notes.extend(mem_notes) + + _add_record( + tbl, + [ + "Compute Optimizer", + self._runtime_options.enable_compute_optimizer, + "Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0", + ], + ) + _add_record( + tbl, + [ + " - FLOPReduction", + self._runtime_options.enable_compute_optimizer, + "Reduce FLOPs by upstreaming shrinking-sized ops", + ], ) if self._runtime_options.enable_compute_optimizer: if len(self._runtime_options.label_sparsity_ratio) > 0: - feature_map.append( - (" -LabelSparsityOpt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}") + _add_record( + tbl, [" - LabelSparsityOpt", True, f"Input density: {self._runtime_options.label_sparsity_ratio}"] ) if 
len(self._runtime_options.embed_sparsity_ratio) > 0: - feature_map.append( - (" -EmbedSparsityOpt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}") + _add_record( + tbl, [" - EmbedSparsityOpt", True, f"Input density: {self._runtime_options.embed_sparsity_ratio}"] ) # Add fallback - feature_map.append( - ( + _add_record( + tbl, + [ "Auto Fallback", self._runtime_options.fallback_policy is not _FallbackPolicy.FALLBACK_DISABLE, "Fallback to PyTorch when encountering unsupported ops", - ) + ], ) - if self._runtime_options.enable_triton: - feature_map.append( - ( - "TritonOp Enabled", - True, - "ORT will switch to Triton for executing some ops to further accelerate training.", - ) - ) + # Add Triton + _add_record( + tbl, + [ + "TritonOp Enabled", + self._runtime_options.enable_triton, + "ORT will switch to Triton for executing some ops to further accelerate training.", + ], + ) if self._runtime_options.enable_tuning: desc = "Enable tunning Ops online" if self._runtime_options.tuning_results_path: desc += f", save tuning results to {self._runtime_options.tuning_results_path}" - feature_map.append(("Online Op Tuning", True, desc)) + _add_record(tbl, ["Online Op Tuning", True, desc]) elif self._runtime_options.tuning_results_path: - feature_map.append( - ( + _add_record( + tbl, + [ "Offline Op Tuning", True, f"Use offline tuning results from {self._runtime_options.tuning_results_path}", - ) + ], ) - feature_map.append( - ( + _add_record( + tbl, + [ "ZeRO Stage3 Support", self._runtime_options.enable_zero_stage3_support, "Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0", - ) + ], ) mode = "training" if self._export_mode == torch.onnx.TrainingMode.TRAINING else "inference" mode = f"{_logger.LogColor.UNDERLINE}{mode}{_logger.LogColor.ENDC}" - - stat = f"\n\n{_logger.LogColor.HEADER}***** ONNX Runtime Training (ORTModule) is accelerating your model *****{_logger.LogColor.ENDC}\n\n" + stat = f"\n{_logger.LogColor.HEADER}***** ONNX Runtime Training (ORTModule) is accelerating your model *****{_logger.LogColor.ENDC}\n\n" stat += f"ORTModule is enabled with following features ON/OFF for [{mode}] mode:\n\n" - for feature_tuple in feature_map: - switch_str = "ON" if feature_tuple[1] else "OFF" - stat += f"{feature_tuple[0]:<20}:\t{switch_str:<10}:\t{feature_tuple[2]:<80}\n" + stat += tbl.get_string() + "\n" # Collect ORTModule overheads for different phases. 
stat += f"\n{self.time_tracker.to_string(self._debug_options.logging.log_level < LogLevel.WARNING)}\n" - stat += f"Versions: ONNX Runtime - {onnxruntime.__version__}, ONNX - {onnx.__version__}\n\n" - stat += f"{_logger.LogColor.HEADER}************************************************************************{_logger.LogColor.ENDC}\n\n" + # Add notes + for index, note in enumerate(notes): + stat += f"Note {index + 1}: {note}\n" + + stat += f"\n{_logger.LogColor.HEADER}************************************************************************{_logger.LogColor.ENDC}\n\n" self._logger.warning(stat) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 1b6e2df9d2e1c..f5fbd5093fca3 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -210,6 +210,7 @@ def _expand_inputs(current_input, non_none_inputs, name=""): result = [] embed_sparsity_results = OrderedDict() label_sparsity_results = OrderedDict() + onnx_input_to_value_map = OrderedDict() for input_idx, name in enumerate(onnx_input_names): inp = None @@ -251,6 +252,8 @@ def _expand_inputs(current_input, non_none_inputs, name=""): if label_density < 100: label_sparsity_results[name] = label_density result.append(inp) + + onnx_input_to_value_map[name] = inp else: raise wrap_exception( ORTModuleONNXModelException, RuntimeError(f"Input is present in ONNX graph but not provided: {name}.") @@ -264,6 +267,10 @@ def _expand_inputs(current_input, non_none_inputs, name=""): else: result.extend(params) + if rt_inspector.memory_ob.is_enabled() and not rt_inspector.memory_ob.symbolic_dim_collecting_completed: + rt_inspector.memory_ob.collect_symbolic_dim_values(input_info.dynamic_axes, onnx_input_to_value_map) + rt_inspector.memory_ob.symbolic_dim_collecting_completed = True + return result, embed_sparsity_results, label_sparsity_results diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index 0728ebdf19af8..a01db28374b8d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -263,7 +263,7 @@ def wrapper(graph_execution_manager, *args, **kwargs): raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.") with _suppress_os_stream_output( - enable=graph_execution_manager._debug_options.log_level >= LogLevel.INFO, + enable=graph_execution_manager._debug_options.log_level >= LogLevel.DEVINFO, on_exit=partial( _log_with_filter, graph_execution_manager._logger, diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index dda909e8cb0f1..cfd2e25e13e26 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -5,12 +5,18 @@ from enum import IntEnum from logging import Logger -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import onnx import torch from onnx import ModelProto, helper from onnx import onnx_pb as onnx_proto +from sympy import Symbol, simplify +from sympy.parsing.sympy_parser import parse_expr + +from onnxruntime.training.utils import PTable + +from ._execution_agent import TrainingAgent class Phase(IntEnum): @@ -39,11 +45,11 @@ 
class RuntimeInspector: Runtime inspector for ORTModule. """ - def __init__(self, logger: Logger): + def __init__(self, logger: Logger, module: torch.nn.Module): self._logger = logger self.input_density_ob: Union[InputDensityObserver, None] = None - self.memory_ob: Union[MemoryObserver, None] = None + self.memory_ob = MemoryObserver(module, self._logger) def enable_input_inspector(self, model: ModelProto, user_input_names: List[str]) -> None: """Initialize input inspector from the given ONNX model and user input names. @@ -82,26 +88,6 @@ def disable_input_inspector(self) -> None: """Disable input density inspector.""" self.input_density_ob = None - def enable_memory_inspector(self, module: torch.nn.Module): - """Enable memory inspector for ORTModule. - - Args: - module: ORTModule. - """ - if self.memory_ob is None: - self.memory_ob = MemoryObserver(module, self._logger) - else: - raise RuntimeError("Memory observer is already enabled.") - - def inspect_memory(self, phase: Phase) -> None: - """Inspect memory usage and print statistics. - - Args: - phase: Phase to inspect. - """ - if self.memory_ob is not None: - self.memory_ob.inspect_memory(phase) - class InputDensityObserver: """Training input data observer for ORTModule. @@ -460,6 +446,16 @@ def _try_get_initializer_value(self, model, name): return value +class MemoryOptimizationSummary: + """Memory optimization summary for a cluster id combination.""" + + def __init__(self, saving_str="", simplified_saving_expr=None, evaluated_saving=None, freq=0): + self.raw_symbolic_saving_str = saving_str + self.simplified_symbolic_saving_expr: Optional[Symbol] = simplified_saving_expr + self.evaluated_saving: Union[str, int, None] = evaluated_saving + self.freq = freq + + class MemoryObserver: """Memory inspector across the training lifetime. @@ -472,6 +468,19 @@ class MemoryObserver: def __init__(self, m: torch.nn.Module, logger: Logger): self._logger = logger + self._is_enabled = True + + # Memory optimization related. + self.memory_optimization_opportunity_table_str = None + self.cluster_id_combination_to_saving_symbolics_map: Dict[str, MemoryOptimizationSummary] = {} + ## The value is a list of symbolic dim values parsed from the first batch. + self.symbolic_dim_name_to_value_map: Dict = {} + + ## Used to control only the first batch is used to collect symbolic dim values. + self.symbolic_dim_collecting_completed = False + + # For per-step memory inspection. + self._print_memory_stats_by_step = False self._current_step = 0 self._rank = 0 self._world_size = 1 @@ -485,8 +494,77 @@ def __init__(self, m: torch.nn.Module, logger: Logger): self._is_first_inspect = True + def is_enabled(self) -> bool: + """Check if memory inspector is enabled.""" + return self._is_enabled + + def enable_memory_stats_by_step(self, print_memory_stats_by_step: bool): + # For per-step memory inspection. 
+ self._print_memory_stats_by_step = print_memory_stats_by_step + + def collect_symbolic_dim_values( + self, + onnx_input_name_to_dynamic_axes_map: Dict[str, Dict[int, str]], + onnx_input_to_value_map: Dict[str, torch.Tensor], + ): + """Collect symbolic dim values.""" + for input_name, dynamic_axes in onnx_input_name_to_dynamic_axes_map.items(): + if input_name in onnx_input_to_value_map: + for dim_idx, dim_name in dynamic_axes.items(): + self.symbolic_dim_name_to_value_map[Symbol(dim_name)] = onnx_input_to_value_map[input_name].size()[ + dim_idx + ] + + def find_memory_optimization_opportunity( + self, execution_agent: TrainingAgent, memory_optimizer_config, probe_level + ): + """Find memory optimization opportunity. + + Args: + execution_agent: TrainingAgent. + memory_optimizer_config: Memory optimization config. + probe_level: Memory probe level. + """ + ( + self.memory_optimization_opportunity_table_str, + memory_optimization_saving_symbolics, + ) = execution_agent.get_serialized_ortmodule_memory_stat(memory_optimizer_config, probe_level) + + cluster_id_to_saving_symbol_map: Dict[str, MemoryOptimizationSummary] = {} + for cluster_id, memory_saving_stat in memory_optimization_saving_symbolics.items(): + memory_saving_symbolic = memory_saving_stat[0] + freq = memory_saving_stat[1] + expr = parse_expr(memory_saving_symbolic) + simplified_expr = simplify(expr) + r = simplified_expr.evalf(subs=self.symbolic_dim_name_to_value_map) + evaluated_saving = None + if r.is_number: + evaluated_saving = float(r) + else: + evaluated_saving = r + + cluster_id_to_saving_symbol_map[cluster_id] = MemoryOptimizationSummary( + memory_saving_symbolic, simplified_expr, evaluated_saving, freq + ) + + # Sorted by evaluated_saving if it is a float + sorted_list = sorted( + cluster_id_to_saving_symbol_map.items(), + key=lambda x: x[1].evaluated_saving if isinstance(x[1].evaluated_saving, float) else 0, + reverse=True, + ) + + for cluster_id, values in sorted_list: + self.cluster_id_combination_to_saving_symbolics_map[cluster_id] = values + def inspect_memory(self, cur_phase: Phase): - if not torch.cuda.is_available(): + """Inspect memory usage and print statistics. + + Args: + phase: Phase to inspect. 
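+        Note: stats are collected only when CUDA is available and per-step memory stats are enabled; they are printed
+        only on rank 0, and after the first 10 steps only when the step count is a power of two.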
+ """ + + if not torch.cuda.is_available() or not self._print_memory_stats_by_step: return if self._is_first_inspect: @@ -498,36 +576,38 @@ def inspect_memory(self, cur_phase: Phase): if self._rank != 0: return - if cur_phase < Phase.PRE_FORWARD or cur_phase > self._last_phase: - raise RuntimeError(f"Invalid phase detected: {cur_phase}") + if cur_phase < Phase.PRE_FORWARD or (cur_phase <= self._last_phase): + raise RuntimeError(f"Invalid phase detected: {cur_phase}, last_phase: {self._last_phase}") if (cur_phase - self._pre_phase) != 1: raise RuntimeError(f"Invalid phase transition detected: {self._pre_phase} -> {cur_phase}") - cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) - max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated()) - cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) - max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) - torch_mem_stat = torch.cuda.memory_stats() - cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) - max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) - - mem_stats = [ - ["phase", _convert_phase_to_string(cur_phase)], - ["allocated", cur_mem_allocated], # current memory alloeated for tensors - ["max allocated", max_mem_allocated], # peak memory allocated for tensors - ["cached", cur_mem_cached], # current memory cached for caching allocator - ["max cached", max_mem_cached], # peak memory cached for caching allocator. - ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory - ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory - ] - - summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" - for stat in mem_stats: - summ += f" | {stat[0]}: {stat[1]}" - # For the 10+ steps, only print when it is power of 2. - if self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0): + need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0) + + if need_print: + cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) + max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated()) + cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) + max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) + torch_mem_stat = torch.cuda.memory_stats() + cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) + max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) + + mem_stats = [ + ["phase", _convert_phase_to_string(cur_phase)], + ["allocated", cur_mem_allocated], # current memory allocated for tensors + ["max allocated", max_mem_allocated], # peak memory allocated for tensors + ["cached", cur_mem_cached], # current memory cached for the caching allocator + ["max cached", max_mem_cached], # peak memory cached for caching allocator. 
+ ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory + ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory + ] + + summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" + for stat in mem_stats: + summ += f" | {stat[0]}: {stat[1]}" + self._logger.info(summ) if cur_phase == self._last_phase: @@ -542,3 +622,72 @@ def _increase_step(self): def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" + + def display_memory_optimization_plans(self, memory_optimizer_config) -> Tuple[List[str], PTable]: + mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) + + if mem_plan_count > 0: + mem_tbl = PTable() + mem_tbl.add_row(["", "", "", "", "Configs", "Freq", "Max Saving(Bytes)", "Saving Symbolic(Bytes)"]) + + index = 1 + + def _get_user_config_without_freq(configs: str): + if len(configs) == 0: + return [] + config_list = configs.split(",") + configs_with_out_freq = [] + for config in config_list: + config_values = config.split(":") + freq = int(config_values[2]) + if freq == 0: + continue + configs_with_out_freq.append(config_values[0] + ":" + config_values[1]) + + return configs_with_out_freq + + user_configs_with_out_freq = _get_user_config_without_freq(memory_optimizer_config) + + for ( + cluster_id, + saving_symbolic, + ) in self.cluster_id_combination_to_saving_symbolics_map.items(): + saving_bytes = saving_symbolic.evaluated_saving + if isinstance(saving_bytes, float): + saving_bytes = f"{saving_bytes:,.0f}" + + cluster_ids_without_freq = _get_user_config_without_freq(cluster_id) + + mem_tbl.add_row( + [ + f" - Plan {index}", + ":", + "ON" + if all(cluster_id in user_configs_with_out_freq for cluster_id in cluster_ids_without_freq) + else "OFF", + ":", + cluster_id, + saving_symbolic.freq, + saving_bytes, + saving_symbolic.simplified_symbolic_saving_expr, + ] + ) + + index += 1 + + saving_recommendation = ( + "use comma as delimiter to enable multiple memory optimization plans at the same time:\n" + ) + saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=,,..." 
+ + notes = [] + notes.append(saving_recommendation) + + saving_recommendation = "memory saving is calculated based on the 1st batch symbolic dim values:\n" + for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items(): + saving_recommendation += f" {dim_param}={dim_value}," + notes.append(saving_recommendation) + + return notes, mem_tbl + + return [], None diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index bafb64235546b..96a95557bb9a1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -18,7 +18,7 @@ from ._gradient_accumulation_manager import GradientAccumulationManager from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo from ._io import _FlattenedModule, _InputInfo, unflatten_user_output -from ._logger import ORTModuleInitPhase, TrackTime +from ._logger import LogLevel, ORTModuleInitPhase, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results from .graph_optimizer_registry import GraphOptimizerRegistry @@ -111,7 +111,7 @@ def forward(ctx, *inputs): Module outputs are returned to the user """ - self._runtime_inspector.inspect_memory(Phase.PRE_FORWARD) + self._runtime_inspector.memory_ob.inspect_memory(Phase.PRE_FORWARD) if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: # Assert that the input and model device match @@ -146,7 +146,7 @@ def forward(ctx, *inputs): for idx in self._graph_info.output_grad_indices_non_differentiable: ctx.mark_non_differentiable(user_outputs[idx]) - self._runtime_inspector.inspect_memory(Phase.POST_FORWARD) + self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_FORWARD) return user_outputs @@ -154,7 +154,7 @@ def forward(ctx, *inputs): def backward(ctx, *grad_outputs): """Performs backward pass based on grad wrt module output""" - self._runtime_inspector.inspect_memory(Phase.PRE_BACKWARD) + self._runtime_inspector.memory_ob.inspect_memory(Phase.PRE_BACKWARD) assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()" if self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: @@ -205,7 +205,7 @@ def backward(ctx, *grad_outputs): # This version only works if backward_outputs is an OrtValueVector. transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - self._runtime_inspector.inspect_memory(Phase.POST_BACKWARD) + self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) @@ -242,7 +242,6 @@ def forward(self, *inputs, **kwargs): self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT), self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE), ) - # If exporting module to ONNX for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. build_gradient_graph = False @@ -433,6 +432,39 @@ def _create_execution_agent(self): local_device_rank = self._device.index if device_type == "ort" else _utils.get_device_index(self._device) + # When log level is <= INFO, we would collect memory optimization opportunities. + # (TODO: consider to enable by default once memory optimization feature is stable and well improved.) 
+ # Create a training agent without enabling memory optimization here is beneficial for memory analyzing + # when we have an allocation plan in place, and reuse information is available. + if self._runtime_inspector.memory_ob.is_enabled() and self._debug_options.log_level <= LogLevel.INFO: + # Create a training agent without enabling memory optimization. + execution_agent = TrainingAgent( + self._onnx_models.optimized_model.SerializeToString(), + fw_feed_names, + fw_outputs_device_info, + bw_fetches_names, + bw_outputs_device_info, + session_options, + providers, + provider_options, + local_device_rank, + ) + + self._runtime_inspector.memory_ob.find_memory_optimization_opportunity( + execution_agent, self._runtime_options.memory_optimizer_config, self._runtime_options.probe_level + ) + + # Release it as early as possible. + del execution_agent + + # Enable memory optimization if it is enabled in the session options. + session_options.add_session_config_entry( + "optimization.memory_optimizer_config", self._runtime_options.memory_optimizer_config + ) + session_options.add_session_config_entry( + "optimization.enable_memory_probe_recompute_level", self._runtime_options.probe_level + ) + self._execution_agent = TrainingAgent( self._onnx_models.optimized_model.SerializeToString(), fw_feed_names, diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py index d215e12f8137a..3d3538a62da61 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -5,9 +5,16 @@ import os +import torch +from packaging.version import Version + _all_optimizers = [] -if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1: +if ( + "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ + and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1 + and Version(torch.__version__) >= Version("2.1.1") +): from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 _all_optimizers.append("optimize_graph_for_aten_efficient_attention") diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py index 94bd41293b427..b1e8809f03fc0 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -245,31 +245,25 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro ("MatMul", False, []), # 0 ("Mul", True, [(0, 0, 0)]), # 1 ("Mul", True, [(0, 0, 1)]), # 2 - ("Cast", True, [(1, 0, 0)]), # 3 - ("Cast", True, [(2, 0, 0)]), # 4 - ("Transpose", True, [(3, 0, 0)]), # 5 - ("Transpose", True, [(4, 0, 0)]), # 6 - ("Softmax", False, [(0, 0, 0)]), # 7 - ("Cast", False, [(7, 0, 0)]), # 8 - ("MatMul", False, [(8, 0, 0)]), # 9 - ("Transpose", True, [(9, 0, 1)]), # 10 - ("Transpose", False, [(9, 0, 0)]), # 11 - ("FusedMatMul", False, [(10, 0, 1)]), # 12 - ("Cast", False, [(12, 0, 0)]), # 13 - ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14 - ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15 - ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16 - ("Mul", False, [(15, 0, 0)]), # 17 - ("Mul", False, [(16, 0, 0)]), # 18 - ("Identity", False, [(17, 0, 0)]), # 19 - 
("Identity", False, [(18, 0, 0)]), # 20 - ("Cast", False, [(19, 0, 0)]), # 21 - ("Cast", False, [(20, 0, 0)]), # 22 - ("Transpose", False, [(21, 0, 0)]), # 23 - ("Transpose", False, [(22, 0, 0)]), # 24 - ("FusedMatMul", False, [(8, 0, 0)]), # 25 - ("Transpose", True, [(25, 0, 1)]), # 26 - ("Transpose", False, [(25, 0, 0)]), # 27 + ("Transpose", True, [(1, 0, 0)]), # 3 + ("Transpose", True, [(2, 0, 0)]), # 4 + ("Softmax", False, [(0, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11 + ("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12 + ("Mul", False, [(11, 0, 0)]), # 13 + ("Mul", False, [(12, 0, 0)]), # 14 + ("Identity", False, [(13, 0, 0)]), # 15 + ("Identity", False, [(14, 0, 0)]), # 16 + ("Transpose", False, [(15, 0, 0)]), # 17 + ("Transpose", False, [(16, 0, 0)]), # 18 + ("FusedMatMul", False, [(5, 0, 0)]), # 19 + ("Transpose", True, [(19, 0, 1)]), # 20 + ("Transpose", False, [(19, 0, 0)]), # 21 ] @@ -280,27 +274,24 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 if not ( - check_attribute_value(nodes[3], "to", 1) - and check_attribute_value(nodes[4], "to", 1) - and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[8], "to", 10) - and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]) + check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) and scale_value_1 == scale_value_2 ): return [], [], [] nodes_to_add, new_value_infos = _make_efficient_attention_nodes( idx, - nodes[5].input[0], - nodes[6].input[0], - nodes[10].input[0], - nodes[11].output[0], - nodes[26].input[0], - nodes[23].output[0], - nodes[24].output[0], - nodes[27].output[0], + nodes[3].input[0], + nodes[4].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[20].input[0], + nodes[17].output[0], + nodes[18].output[0], + nodes[21].output[0], "", False, scale_value_1, @@ -315,39 +306,32 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro ("MatMul", False, []), # 0 ("Mul", True, [(0, 0, 0)]), # 1 ("Mul", True, [(0, 0, 1)]), # 2 - ("Cast", True, [(1, 0, 0)]), # 3 - ("Cast", True, [(2, 0, 0)]), # 4 - ("Transpose", True, [(3, 0, 0)]), # 5 - ("Transpose", True, [(4, 0, 0)]), # 6 - ("Add", False, [(0, 0, 0)]), # 7 - ("Cast", True, [(7, 0, 1)]), # 8 - ("Slice", True, [(8, 0, 0)]), # 9 - ("Slice", True, [(9, 0, 0)]), # 10 - ("Unsqueeze", True, [(9, 0, 2)]), # 11 - ("Gather", True, [(11, 0, 0)]), # 12 - ("Shape", True, [(12, 0, 0)]), # 13 - ("Softmax", False, [(7, 0, 0)]), # 14 - ("Cast", False, [(14, 0, 0)]), # 15 - ("MatMul", False, [(15, 0, 0)]), # 16 - ("Transpose", True, [(16, 0, 1)]), # 17 - ("Transpose", False, [(16, 0, 0)]), # 18 - ("FusedMatMul", False, [(17, 0, 1)]), # 19 - ("Cast", False, [(19, 0, 0)]), # 20 - ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21 - ("Identity", False, [(21, 0, 0)]), # 22 - ("FusedMatMul", 
False, [(2, 0, 1), (22, 0, 0)]), # 23 - ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24 - ("Mul", False, [(23, 0, 0)]), # 25 - ("Mul", False, [(24, 0, 0)]), # 26 - ("Identity", False, [(25, 0, 0)]), # 27 - ("Identity", False, [(26, 0, 0)]), # 28 - ("Cast", False, [(27, 0, 0)]), # 29 - ("Cast", False, [(28, 0, 0)]), # 30 - ("Transpose", False, [(29, 0, 0)]), # 31 - ("Transpose", False, [(30, 0, 0)]), # 32 - ("FusedMatMul", False, [(15, 0, 0)]), # 33 - ("Transpose", True, [(33, 0, 1)]), # 34 - ("Transpose", False, [(33, 0, 0)]), # 35 + ("Transpose", True, [(1, 0, 0)]), # 3 + ("Transpose", True, [(2, 0, 0)]), # 4 + ("Add", False, [(0, 0, 0)]), # 5 + ("Slice", True, [(5, 0, 1)]), # 6 + ("Slice", True, [(6, 0, 0)]), # 7 + ("Unsqueeze", True, [(6, 0, 2)]), # 8 + ("Gather", True, [(8, 0, 0)]), # 9 + ("Shape", True, [(9, 0, 0)]), # 10 + ("Softmax", False, [(5, 0, 0)]), # 11 + ("MatMul", False, [(11, 0, 0)]), # 12 + ("Transpose", True, [(12, 0, 1)]), # 13 + ("Transpose", False, [(12, 0, 0)]), # 14 + ("FusedMatMul", False, [(13, 0, 1)]), # 15 + ("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16 + ("Identity", False, [(16, 0, 0)]), # 17 + ("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18 + ("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19 + ("Mul", False, [(18, 0, 0)]), # 20 + ("Mul", False, [(19, 0, 0)]), # 21 + ("Identity", False, [(20, 0, 0)]), # 22 + ("Identity", False, [(21, 0, 0)]), # 23 + ("Transpose", False, [(22, 0, 0)]), # 24 + ("Transpose", False, [(23, 0, 0)]), # 25 + ("FusedMatMul", False, [(11, 0, 0)]), # 26 + ("Transpose", True, [(26, 0, 1)]), # 27 + ("Transpose", False, [(26, 0, 0)]), # 28 ] @@ -358,27 +342,24 @@ def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodePro scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 if not ( - check_attribute_value(nodes[3], "to", 1) - and check_attribute_value(nodes[4], "to", 1) - and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[15], "to", 10) - and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]) + check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3]) and scale_value_1 == scale_value_2 ): return [], [], [] nodes_to_add, new_value_infos = _make_efficient_attention_nodes( idx, - nodes[5].input[0], - nodes[6].input[0], - nodes[17].input[0], - nodes[18].output[0], - nodes[34].input[0], - nodes[31].output[0], - nodes[32].output[0], - nodes[35].output[0], + nodes[3].input[0], + nodes[4].input[0], + nodes[13].input[0], + nodes[14].output[0], + nodes[27].input[0], + nodes[24].output[0], + nodes[25].output[0], + nodes[28].output[0], "", False, scale_value_1, diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index cddd9cd440b28..77022f86d3ff3 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -137,7 +137,7 @@ def logging(self): def torch_exporter_filter(self): """Accessor for the filter export logs configuration.""" torch_version = get_runtime_pytorch_version() - if self.log_level >= 
LogLevel.INFO: + if self.log_level > LogLevel.DEVINFO: if torch_version < version.parse("2.0"): return [ # WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. @@ -262,7 +262,7 @@ def __init__(self, logger: Logger): # Configuration for dev tools. self.print_input_density = False - self.print_memory_stat = False + self.print_memory_stat_by_step = False # Configuration for fallback. self.fallback_policy = ortmodule.ORTMODULE_FALLBACK_POLICY @@ -321,7 +321,7 @@ def _override_from_env_vars(self): if "ORTMODULE_PRINT_INPUT_DENSITY" in os.environ: self.print_input_density = int(os.getenv("ORTMODULE_PRINT_INPUT_DENSITY")) == 1 if "ORTMODULE_PRINT_MEMORY_STATS" in os.environ: - self.print_memory_stat = int(os.getenv("ORTMODULE_PRINT_MEMORY_STATS")) == 1 + self.print_memory_stat_by_step = int(os.getenv("ORTMODULE_PRINT_MEMORY_STATS")) == 1 # Configuration for fallback. if "ORTMODULE_FALLBACK_POLICY" in os.environ: diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index d40a6ddf7daf3..244557c3c1072 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. # __init__.py +from onnxruntime.training.utils.ptable import PTable from onnxruntime.training.utils.torch_io_helper import ( ORTModelInputOutputSchemaType, ORTModelInputOutputType, @@ -24,4 +25,5 @@ "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", + "PTable", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 0d268a7a4a5cf..61f3b20224a72 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -291,7 +291,7 @@ def backward(ctx, *grads): raise RuntimeError(f"param {p} has no grad, this should not happen.") # Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch. assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}" - p.backward(g) + # p.backward(g) # At this point, the **real** param grads are already updated, the following grads are only used for # completing the full backward propagation, will not affect parameter updates. diff --git a/orttraining/orttraining/python/training/utils/ptable.py b/orttraining/orttraining/python/training/utils/ptable.py new file mode 100644 index 0000000000000..3b3b80d29ed92 --- /dev/null +++ b/orttraining/orttraining/python/training/utils/ptable.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import List + + +class Row: + """A row in a PTable""" + + def __init__(self, columns: List[str]) -> None: + self._columns: List[str] = columns # List of strings + self._annotation_table = None # Optional PTable used for displaying detailed information about the feature row. 
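+        # When set (via append_annotation_table below), PTable.get_string() renders the nested table
+        # directly under this row, padding its first and third columns to at least the parent table's widths.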
+ + def append_annotation_table(self, ptable) -> None: + self._annotation_table = ptable + + +class PTable: + """A table that can be printed to the console.""" + + def __init__(self) -> None: + self._rows: List[Row] = [] + self._column_count = None + + def add_row(self, columns: List[str]) -> Row: + """Add a row to the table. The number of columns must match the number of columns in the table.""" + if self._column_count is None: + self._column_count = len(columns) + assert self._column_count == len(columns) + row = Row(columns) + self._rows.append(row) + return row + + def get_string(self, first_column_width=None, second_column_width=None) -> str: + """Serialize the table to a string.""" + # Collect the max width of each column + column_widths = [] + for row in self._rows: + if column_widths: + assert len(column_widths) == len(row._columns) + else: + column_widths = [0] * len(row._columns) + for i, column in enumerate(row._columns): + column_widths[i] = max(column_widths[i], len(str(column))) + + if first_column_width: + column_widths[0] = max(first_column_width, column_widths[0]) + + if second_column_width: + column_widths[2] = max(second_column_width, column_widths[2]) + + serialized_table = "" + for row in self._rows: + for i, column in enumerate(row._columns): + serialized_table += f"{str(column).ljust(column_widths[i] + 2)}" + serialized_table += "\n" + if row._annotation_table: + serialized_table += row._annotation_table.get_string( + first_column_width=column_widths[0], second_column_width=column_widths[2] + ) + + return serialized_table diff --git a/orttraining/orttraining/test/gradient/optimizer_ops_test.cc b/orttraining/orttraining/test/gradient/optimizer_ops_test.cc index c100730aacc44..bfb59f1525e47 100644 --- a/orttraining/orttraining/test/gradient/optimizer_ops_test.cc +++ b/orttraining/orttraining/test/gradient/optimizer_ops_test.cc @@ -1542,7 +1542,6 @@ TEST(OptimizerTest, LambOptimizerTestLarge) { std::vector m(size); std::vector v(size); - std::random_device random_device; std::mt19937 random_engine(0); std::uniform_real_distribution dist(0.1f, 1.0f); for (int i = 0; i < size; ++i) { @@ -1581,7 +1580,6 @@ TEST(OptimizerTest, LambOptimizerTestLarge) { TEST(OptimizerTest, LambOptimizerMultiTensorRatio) { constexpr int group_count = 127; - std::random_device random_device; std::mt19937 random_engine(0); std::uniform_real_distribution dist(0.1f, 1.0f); std::uniform_int_distribution dist_int(1, 1228); diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index f7a7220dd66ea..6e5d54cbb9427 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -17,6 +17,14 @@ # PyTorch Module definitions +def get_opsets_model(filename): + if isinstance(filename, onnx.ModelProto): + onx = filename + else: + onx = onnx.load(filename) + return {d.domain: d.version for d in onx.opset_import} + + class SimpleNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super().__init__() @@ -999,3 +1007,13 @@ def test_save_ort_format(): assert os.path.exists(os.path.join(temp_dir, "eval_model.ort")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.ort")) + base_opsets = get_opsets_model(base_model) + training_opsets = get_opsets_model(os.path.join(temp_dir, 
"training_model.onnx")) + eval_opsets = get_opsets_model(os.path.join(temp_dir, "eval_model.onnx")) + optimizer_opsets = get_opsets_model(os.path.join(temp_dir, "optimizer_model.onnx")) + if base_opsets[""] != training_opsets[""]: + raise AssertionError(f"Opsets mismatch {base_opsets['']} != {training_opsets['']}.") + if base_opsets[""] != eval_opsets[""]: + raise AssertionError(f"Opsets mismatch {base_opsets['']} != {eval_opsets['']}.") + if base_opsets[""] != optimizer_opsets[""]: + raise AssertionError(f"Opsets mismatch {base_opsets['']} != {optimizer_opsets['']}.") diff --git a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc index be8b0aaa0bce1..60c3ecbcce8ce 100644 --- a/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/reduction/reduction_ops_test.cc @@ -275,7 +275,6 @@ void TestMultiTensorReduce( test.SetDeterminism(use_determinism); // Set up random number generator. - std::random_device random_device; std::mt19937 random_engine(0); std::uniform_real_distribution dist(min, max); std::uniform_int_distribution dist_int(min_tensor_size, max_tensor_size); diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 6f492317524be..8ea0481c9b101 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -35,6 +35,9 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("HIPBLAS_OP_T", "rocblas_operation_transpose") s = s.replace("HIPBLAS_OP_N", "rocblas_operation_none") + # in rocm 6.0, hipify-perl, the -roc option also maps __half -> rocblas_half which we don't want + s = s.replace("rocblas_half", "__half") + s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels") s = s.replace("cudaEvent", "hipEvent") s = s.replace("CreateCudaAllocator", "CreateRocmAllocator") diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 3b1a0317c58f1..76cda428cabe3 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -378,8 +378,9 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--gdk_platform", default="Scarlett", help="Sets the GDK target platform.") parser.add_argument("--ios", action="store_true", help="build for ios") + parser.add_argument( - "--ios_sysroot", default="", help="Specify the location name of the macOS platform SDK to be used" + "--apple_sysroot", default="", help="Specify the location name of the macOS platform SDK to be used" ) parser.add_argument( "--ios_toolchain_file", @@ -1273,33 +1274,38 @@ def generate_build_tree( if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] - if args.ios: + if args.build_apple_framework or args.ios: if not args.cmake_generator == "Xcode": - raise BuildError("iOS build requires use of the Xcode CMake generator ('--cmake_generator Xcode').") + raise BuildError( + "iOS/MacOS framework build requires use of the Xcode CMake generator ('--cmake_generator Xcode')." 
+ ) needed_args = [ - args.ios_sysroot, + args.apple_sysroot, args.apple_deploy_target, ] arg_names = [ - "--ios_sysroot " + "", + "--apple_sysroot " + "", "--apple_deploy_target " + "", ] if not all(needed_args): raise BuildError( - "iOS build on MacOS canceled due to missing arguments: " + "iOS/MacOS framework build on MacOS canceled due to missing arguments: " + ", ".join(val for val, cond in zip(arg_names, needed_args) if not cond) ) cmake_args += [ - "-DCMAKE_SYSTEM_NAME=iOS", "-Donnxruntime_BUILD_SHARED_LIB=ON", - "-DCMAKE_OSX_SYSROOT=" + args.ios_sysroot, + "-DCMAKE_OSX_SYSROOT=" + args.apple_sysroot, "-DCMAKE_OSX_DEPLOYMENT_TARGET=" + args.apple_deploy_target, # we do not need protoc binary for ios cross build "-Dprotobuf_BUILD_PROTOC_BINARIES=OFF", - "-DCMAKE_TOOLCHAIN_FILE=" - + (args.ios_toolchain_file if args.ios_toolchain_file else "../cmake/onnxruntime_ios.toolchain.cmake"), ] + if args.ios: + cmake_args += [ + "-DCMAKE_SYSTEM_NAME=iOS", + "-DCMAKE_TOOLCHAIN_FILE=" + + (args.ios_toolchain_file if args.ios_toolchain_file else "../cmake/onnxruntime_ios.toolchain.cmake"), + ] if args.build_wasm: emsdk_dir = os.path.join(cmake_dir, "external", "emsdk") @@ -1761,10 +1767,10 @@ def run_ios_tests(args, source_dir, config, cwd): ) if args.build_apple_framework: - package_test_py = os.path.join(source_dir, "tools", "ci_build", "github", "apple", "test_ios_packages.py") + package_test_py = os.path.join(source_dir, "tools", "ci_build", "github", "apple", "test_apple_packages.py") framework_info_file = os.path.join(cwd, "framework_info.json") - dynamic_framework_dir = os.path.join(cwd, config + "-" + args.ios_sysroot) - static_framework_dir = os.path.join(cwd, config + "-" + args.ios_sysroot, "static_framework") + dynamic_framework_dir = os.path.join(cwd, config + "-" + args.apple_sysroot) + static_framework_dir = os.path.join(cwd, config + "-" + args.apple_sysroot, "static_framework") # test dynamic framework run_subprocess( [ @@ -1774,6 +1780,8 @@ def run_ios_tests(args, source_dir, config, cwd): dynamic_framework_dir, "--framework_info_file", framework_info_file, + "--variant", + "Mobile", ], cwd=cwd, ) @@ -1786,6 +1794,8 @@ def run_ios_tests(args, source_dir, config, cwd): static_framework_dir, "--framework_info_file", framework_info_file, + "--variant", + "Mobile", ], cwd=cwd, ) diff --git a/tools/ci_build/github/apple/assemble_ios_packaging_artifacts.sh b/tools/ci_build/github/apple/assemble_apple_packaging_artifacts.sh similarity index 100% rename from tools/ci_build/github/apple/assemble_ios_packaging_artifacts.sh rename to tools/ci_build/github/apple/assemble_apple_packaging_artifacts.sh diff --git a/tools/ci_build/github/apple/build_and_assemble_ios_pods.py b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py similarity index 82% rename from tools/ci_build/github/apple/build_and_assemble_ios_pods.py rename to tools/ci_build/github/apple/build_and_assemble_apple_pods.py index d3443e6cb0f4d..006dc4c33ffce 100755 --- a/tools/ci_build/github/apple/build_and_assemble_ios_pods.py +++ b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py @@ -32,13 +32,13 @@ def parse_args(): parser.add_argument( "--build-dir", type=pathlib.Path, - default=REPO_DIR / "build" / "ios_framework", + default=REPO_DIR / "build" / "apple_framework", help="The build directory. 
This will contain the iOS framework build output.", ) parser.add_argument( "--staging-dir", type=pathlib.Path, - default=REPO_DIR / "build" / "ios_pod_staging", + default=REPO_DIR / "build" / "apple_pod_staging", help="The staging directory. This will contain the iOS pod package files. " "The pod package files do not have dependencies on files in the build directory.", ) @@ -60,20 +60,20 @@ def parse_args(): build_framework_group = parser.add_argument_group( title="iOS framework build arguments", - description="See the corresponding arguments in build_ios_framework.py for details.", + description="See the corresponding arguments in build_apple_framework.py for details.", ) build_framework_group.add_argument("--include-ops-by-config") build_framework_group.add_argument( - "--build-settings-file", required=True, help="The positional argument of build_ios_framework.py." + "--build-settings-file", required=True, help="The positional argument of build_apple_framework.py." ) build_framework_group.add_argument( "-b", - "--build-ios-framework-arg", + "--build-apple-framework-arg", action="append", - dest="build_ios_framework_extra_args", + dest="build_apple_framework_extra_args", default=[], - help="Pass an argument through to build_ios_framework.py. This may be specified multiple times.", + help="Pass an argument through to build_apple_framework.py. This may be specified multiple times.", ) args = parser.parse_args() @@ -101,27 +101,27 @@ def main(): # build framework package_variant = PackageVariant[args.variant] - framework_info_file = build_dir / "framework_info.json" + framework_info_file = build_dir / "xcframework_info.json" - log.info("Building iOS framework.") + log.info("Building Apple framework.") - build_ios_framework_args = [ + build_apple_framework_args = [ sys.executable, - str(SCRIPT_DIR / "build_ios_framework.py"), - *args.build_ios_framework_extra_args, + str(SCRIPT_DIR / "build_apple_framework.py"), + *args.build_apple_framework_extra_args, ] if args.include_ops_by_config is not None: - build_ios_framework_args += ["--include_ops_by_config", args.include_ops_by_config] + build_apple_framework_args += ["--include_ops_by_config", args.include_ops_by_config] - build_ios_framework_args += ["--build_dir", str(build_dir), args.build_settings_file] + build_apple_framework_args += ["--build_dir", str(build_dir), args.build_settings_file] - run(build_ios_framework_args) + run(build_apple_framework_args) if args.test: - test_ios_packages_args = [ + test_apple_packages_args = [ sys.executable, - str(SCRIPT_DIR / "test_ios_packages.py"), + str(SCRIPT_DIR / "test_apple_packages.py"), "--fail_if_cocoapods_missing", "--framework_info_file", str(framework_info_file), @@ -131,7 +131,7 @@ def main(): package_variant.name, ] - run(test_ios_packages_args) + run(test_apple_packages_args) # assemble pods and then move them to their target locations (staging_dir/) staging_dir.mkdir(parents=True, exist_ok=True) diff --git a/tools/ci_build/github/apple/build_ios_framework.py b/tools/ci_build/github/apple/build_apple_framework.py similarity index 81% rename from tools/ci_build/github/apple/build_ios_framework.py rename to tools/ci_build/github/apple/build_apple_framework.py index 7983581f07fd6..5137a0644b2e7 100644 --- a/tools/ci_build/github/apple/build_ios_framework.py +++ b/tools/ci_build/github/apple/build_apple_framework.py @@ -30,19 +30,17 @@ def _parse_build_settings(args): build_settings["build_osx_archs"] = build_settings_data.get("build_osx_archs", DEFAULT_BUILD_OSX_ARCHS) - build_params = [] if 
"build_params" in build_settings_data: - build_params += build_settings_data["build_params"] + build_settings["build_params"] = build_settings_data["build_params"] else: raise ValueError("build_params is required in the build config file") - build_settings["build_params"] = build_params return build_settings # Build fat framework for all archs of a single sysroot # For example, arm64 and x86_64 for iphonesimulator -def _build_for_ios_sysroot( +def _build_for_apple_sysroot( build_config, intermediates_dir, base_build_command, sysroot, archs, build_dynamic_framework ): # paths of the onnxruntime libraries for different archs @@ -54,7 +52,7 @@ def _build_for_ios_sysroot( build_dir_current_arch = os.path.join(intermediates_dir, sysroot + "_" + current_arch) build_command = [ *base_build_command, - "--ios_sysroot=" + sysroot, + "--apple_sysroot=" + sysroot, "--osx_arch=" + current_arch, "--build_dir=" + build_dir_current_arch, ] @@ -103,6 +101,20 @@ def _build_for_ios_sysroot( return framework_dir +def _merge_framework_info_files(files, output_file): + merged_data = {} + + for file in files: + with open(file) as f: + data = json.load(f) + for platform, values in data.items(): + assert platform not in merged_data, f"Duplicate platform value: {platform}" + merged_data[platform] = values + + with open(output_file, "w") as f: + json.dump(merged_data, f, indent=2) + + def _build_package(args): build_settings = _parse_build_settings(args) build_dir = os.path.abspath(args.build_dir) @@ -110,20 +122,26 @@ def _build_package(args): # Temp dirs to hold building results intermediates_dir = os.path.join(build_dir, "intermediates") build_config = args.config - base_build_command = [sys.executable, BUILD_PY] + build_settings["build_params"] + ["--config=" + build_config] - - if args.include_ops_by_config is not None: - base_build_command += ["--include_ops_by_config=" + str(args.include_ops_by_config.resolve())] - - if args.path_to_protoc_exe is not None: - base_build_command += ["--path_to_protoc_exe=" + str(args.path_to_protoc_exe.resolve())] # build framework for individual sysroot framework_dirs = [] - framework_info_path = "" + framework_info_files_to_merge = [] public_headers_path = "" for sysroot in build_settings["build_osx_archs"]: - framework_dir = _build_for_ios_sysroot( + base_build_command = ( + [sys.executable, BUILD_PY] + + build_settings["build_params"]["base"] + + build_settings["build_params"][sysroot] + + ["--config=" + build_config] + ) + + if args.include_ops_by_config is not None: + base_build_command += ["--include_ops_by_config=" + str(args.include_ops_by_config.resolve())] + + if args.path_to_protoc_exe is not None: + base_build_command += ["--path_to_protoc_exe=" + str(args.path_to_protoc_exe.resolve())] + + framework_dir = _build_for_apple_sysroot( build_config, intermediates_dir, base_build_command, @@ -132,17 +150,20 @@ def _build_package(args): args.build_dynamic_framework, ) framework_dirs.append(framework_dir) - # podspec and headers for each sysroot are the same, pick one of them - if not framework_info_path: - framework_info_path = os.path.join(os.path.dirname(framework_dir), "framework_info.json") + + curr_framework_info_path = os.path.join(os.path.dirname(framework_dir), "framework_info.json") + framework_info_files_to_merge.append(curr_framework_info_path) + + # headers for each sysroot are the same, pick one of them + if not public_headers_path: public_headers_path = os.path.join(os.path.dirname(framework_dir), "onnxruntime.framework", "Headers") - # create the folder 
for xcframework and copy the LICENSE and podspec file + # create the folder for xcframework and copy the LICENSE and framework_info.json file xcframework_dir = os.path.join(build_dir, "framework_out") pathlib.Path(xcframework_dir).mkdir(parents=True, exist_ok=True) shutil.copy(os.path.join(REPO_DIR, "LICENSE"), xcframework_dir) shutil.copytree(public_headers_path, os.path.join(xcframework_dir, "Headers"), dirs_exist_ok=True) - shutil.copy(framework_info_path, build_dir) + _merge_framework_info_files(framework_info_files_to_merge, os.path.join(build_dir, "xcframework_info.json")) # remove existing xcframework if any xcframework_path = os.path.join(xcframework_dir, "onnxruntime.xcframework") @@ -171,7 +192,7 @@ def parse_args(): parser.add_argument( "--build_dir", type=pathlib.Path, - default=os.path.join(REPO_DIR, "build/iOS_framework"), + default=os.path.join(REPO_DIR, "build/apple_framework"), help="Provide the root directory for build output", ) diff --git a/tools/ci_build/github/apple/c/assemble_c_pod_package.py b/tools/ci_build/github/apple/c/assemble_c_pod_package.py index 14e7729610617..1d7647dd469db 100644 --- a/tools/ci_build/github/apple/c/assemble_c_pod_package.py +++ b/tools/ci_build/github/apple/c/assemble_c_pod_package.py @@ -28,8 +28,6 @@ def get_pod_config_file(package_variant: PackageVariant): return _script_dir / "onnxruntime-c.config.json" elif package_variant == PackageVariant.Mobile: return _script_dir / "onnxruntime-mobile-c.config.json" - elif package_variant == PackageVariant.Test: - return _script_dir / "onnxruntime-test-c.config.json" elif package_variant == PackageVariant.Training: return _script_dir / "onnxruntime-training-c.config.json" else: @@ -49,7 +47,7 @@ def assemble_c_pod_package( :param staging_dir Path to the staging directory for the C/C++ pod files. :param pod_version C/C++ pod version. - :param framework_info_file Path to the framework_info.json file containing additional values for the podspec. + :param framework_info_file Path to the framework_info.json or xcframework_info.json file containing additional values for the podspec. :param public_headers_dir Path to the public headers directory to include in the pod. :param framework_dir Path to the onnxruntime framework directory to include in the pod. :param package_variant The pod package variant. @@ -77,14 +75,16 @@ def assemble_c_pod_package( # generate the podspec file from the template variable_substitutions = { "DESCRIPTION": pod_config["description"], - "IOS_DEPLOYMENT_TARGET": framework_info["IOS_DEPLOYMENT_TARGET"], + # By default, we build both "iphoneos" and "iphonesimulator" architectures, and the deployment target should be the same between these two. + "IOS_DEPLOYMENT_TARGET": framework_info["iphonesimulator"]["APPLE_DEPLOYMENT_TARGET"], + "MACOSX_DEPLOYMENT_TARGET": framework_info.get("macosx", {}).get("APPLE_DEPLOYMENT_TARGET", ""), "LICENSE_FILE": "LICENSE", "NAME": pod_name, "ORT_C_FRAMEWORK": framework_dir.name, "ORT_C_HEADERS_DIR": public_headers_dir.name, "SUMMARY": pod_config["summary"], "VERSION": pod_version, - "WEAK_FRAMEWORK": framework_info["WEAK_FRAMEWORK"], + "WEAK_FRAMEWORK": framework_info["iphonesimulator"]["WEAK_FRAMEWORK"], } podspec_template = _script_dir / "c.podspec.template" @@ -114,7 +114,7 @@ def parse_args(): "--framework-info-file", type=pathlib.Path, required=True, - help="Path to the framework_info.json file containing additional values for the podspec. 
" + help="Path to the framework_info.json or xcframework_info.json file containing additional values for the podspec. " "This file should be generated by CMake in the build directory.", ) parser.add_argument( diff --git a/tools/ci_build/github/apple/c/c.podspec.template b/tools/ci_build/github/apple/c/c.podspec.template index e0cbfe23608fc..a04f20b359229 100644 --- a/tools/ci_build/github/apple/c/c.podspec.template +++ b/tools/ci_build/github/apple/c/c.podspec.template @@ -6,7 +6,13 @@ Pod::Spec.new do |spec| spec.homepage = "https://github.com/microsoft/onnxruntime" spec.source = { :http => "file:///http_source_placeholder" } spec.summary = "@SUMMARY@" - spec.platform = :ios, "@IOS_DEPLOYMENT_TARGET@" + spec.ios.deployment_target = "@IOS_DEPLOYMENT_TARGET@" + + macosx_deployment_target = "@MACOSX_DEPLOYMENT_TARGET@" + if macosx_deployment_target != "" + spec.osx.deployment_target = macosx_deployment_target + end + spec.vendored_frameworks = "@ORT_C_FRAMEWORK@" spec.static_framework = true spec.weak_framework = [ @WEAK_FRAMEWORK@ ] diff --git a/tools/ci_build/github/apple/c/onnxruntime-test-c.config.json b/tools/ci_build/github/apple/c/onnxruntime-test-c.config.json deleted file mode 100644 index d55dbc63e057c..0000000000000 --- a/tools/ci_build/github/apple/c/onnxruntime-test-c.config.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "name": "onnxruntime-test-c", - "summary": "TEST POD", - "description": "Pod for testing. Not for actual release." -} diff --git a/tools/ci_build/github/apple/coreml_supported_ops.md b/tools/ci_build/github/apple/coreml_supported_ops.md index 959177bcb4d7b..e2e43587ab674 100644 --- a/tools/ci_build/github/apple/coreml_supported_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_ops.md @@ -34,6 +34,8 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Shape|Attribute `start` with non-default value is not supported.
Attribute `end` is not supported.| |ai.onnx:Sigmoid|| |ai.onnx:Slice|Inputs `starts`, `ends`, `axes`, and `steps` should be constant. Empty slice is not supported.| +|ai.onnx:Softmax|| +|ai.onnx:Split|If provided, `splits` should be constant. num of outputs supported is at least 2.| |ai.onnx:Squeeze|| |ai.onnx:Sqrt|| |ai.onnx:Sub|| diff --git a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json new file mode 100644 index 0000000000000..86b4efdc63750 --- /dev/null +++ b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json @@ -0,0 +1,37 @@ +{ + "build_osx_archs": { + "iphoneos": [ + "arm64" + ], + "iphonesimulator": [ + "arm64", + "x86_64" + ], + "macosx": [ + "arm64", + "x86_64" + ] + }, + "build_params": { + "base": [ + "--parallel", + "--use_xcode", + "--build_apple_framework", + "--use_coreml", + "--use_xnnpack", + "--skip_tests", + "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF" + ], + "macosx": [ + "--apple_deploy_target=11.0" + ], + "iphoneos": [ + "--ios", + "--apple_deploy_target=12.0" + ], + "iphonesimulator": [ + "--ios", + "--apple_deploy_target=12.0" + ] + } +} diff --git a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json deleted file mode 100644 index 621af55fad7fa..0000000000000 --- a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "build_osx_archs": { - "iphoneos": [ - "arm64" - ], - "iphonesimulator": [ - "arm64", - "x86_64" - ] - }, - "build_params": [ - "--ios", - "--parallel", - "--use_xcode", - "--build_apple_framework", - "--use_coreml", - "--use_xnnpack", - "--skip_tests", - "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF", - "--apple_deploy_target=12.0" - ] -} diff --git a/tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json index 2738a7ca7b009..2bdf8de24f53c 100644 --- a/tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json @@ -8,19 +8,27 @@ "x86_64" ] }, - "build_params": [ - "--ios", - "--parallel", - "--use_xcode", - "--build_apple_framework", - "--minimal_build=extended", - "--disable_rtti", - "--disable_ml_ops", - "--disable_exceptions", - "--enable_reduced_operator_type_support", - "--use_coreml", - "--skip_tests", - "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF", - "--apple_deploy_target=12.0" - ] + "build_params": { + "base": [ + "--parallel", + "--use_xcode", + "--build_apple_framework", + "--minimal_build=extended", + "--disable_rtti", + "--disable_ml_ops", + "--disable_exceptions", + "--enable_reduced_operator_type_support", + "--use_coreml", + "--skip_tests", + "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF" + ], + "iphoneos": [ + "--ios", + "--apple_deploy_target=12.0" + ], + "iphonesimulator": [ + "--ios", + "--apple_deploy_target=12.0" + ] + } } diff --git a/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json index ec7fcafce04f2..f88934cd44a66 100644 --- a/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json 
@@ -6,18 +6,33 @@ "iphonesimulator": [ "arm64", "x86_64" + ], + "macosx": [ + "arm64", + "x86_64" ] }, - "build_params": [ - "--ios", - "--parallel", - "--use_xcode", - "--enable_training_apis", - "--build_apple_framework", - "--use_coreml", - "--use_xnnpack", - "--skip_tests", - "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF", - "--apple_deploy_target=12.0" - ] + "build_params": { + "base": [ + "--parallel", + "--use_xcode", + "--enable_training_apis", + "--build_apple_framework", + "--use_coreml", + "--use_xnnpack", + "--skip_tests", + "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF" + ], + "iphoneos": [ + "--ios", + "--apple_deploy_target=12.0" + ], + "iphonesimulator": [ + "--ios", + "--apple_deploy_target=12.0" + ], + "macosx": [ + "--apple_deploy_target=11.0" + ] + } } diff --git a/tools/ci_build/github/apple/framework_info.json.template b/tools/ci_build/github/apple/framework_info.json.template index 788e52302b3f1..b4c4fb8d16ebf 100644 --- a/tools/ci_build/github/apple/framework_info.json.template +++ b/tools/ci_build/github/apple/framework_info.json.template @@ -1,4 +1,6 @@ { - "IOS_DEPLOYMENT_TARGET": "@CMAKE_OSX_DEPLOYMENT_TARGET@", - "WEAK_FRAMEWORK": "@APPLE_WEAK_FRAMEWORK@" -} \ No newline at end of file + "@CMAKE_OSX_SYSROOT@": { + "APPLE_DEPLOYMENT_TARGET": "@CMAKE_OSX_DEPLOYMENT_TARGET@", + "WEAK_FRAMEWORK": "@APPLE_WEAK_FRAMEWORK@" + } +} diff --git a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py index 135a55165beda..ec1feaae82175 100755 --- a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py +++ b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py @@ -119,7 +119,7 @@ def assemble_objc_pod_package( :param staging_dir Path to the staging directory for the Objective-C pod files. :param pod_version Objective-C pod version. - :param framework_info_file Path to the framework_info.json file containing additional values for the podspec. + :param framework_info_file Path to the framework_info.json or xcframework_info.json file containing additional values for the podspec. :param package_variant The pod package variant. :return Tuple of (package name, path to the podspec file). """ @@ -153,7 +153,7 @@ def path_patterns_as_variable_value(patterns: list[str]): "C_POD_NAME": c_pod_config["name"], "DESCRIPTION": pod_config["description"], "INCLUDE_DIR_LIST": path_patterns_as_variable_value(include_dirs), - "IOS_DEPLOYMENT_TARGET": framework_info["IOS_DEPLOYMENT_TARGET"], + "IOS_DEPLOYMENT_TARGET": framework_info["iphonesimulator"]["APPLE_DEPLOYMENT_TARGET"], "LICENSE_FILE": license_file, "NAME": pod_name, "PUBLIC_HEADER_FILE_LIST": path_patterns_as_variable_value(pod_files["public_header_files"]), @@ -191,7 +191,7 @@ def parse_args(): "--framework-info-file", type=pathlib.Path, required=True, - help="Path to the framework_info.json file containing additional values for the podspec. " + help="Path to the framework_info.json or xcframework_info.json file containing additional values for the podspec. 
" "This file should be generated by CMake in the build directory.", ) parser.add_argument( diff --git a/tools/ci_build/github/apple/package_assembly_utils.py b/tools/ci_build/github/apple/package_assembly_utils.py index e5940774c54f9..bdf359df1dbb8 100644 --- a/tools/ci_build/github/apple/package_assembly_utils.py +++ b/tools/ci_build/github/apple/package_assembly_utils.py @@ -17,7 +17,6 @@ class PackageVariant(enum.Enum): Full = 0 # full ORT build with all opsets, ops, and types Mobile = 1 # minimal ORT build with reduced ops Training = 2 # full ORT build with all opsets, ops, and types, plus training APIs - Test = -1 # for testing purposes only @classmethod def release_variant_names(cls): diff --git a/tools/ci_build/github/apple/test_ios_packages.py b/tools/ci_build/github/apple/test_apple_packages.py similarity index 87% rename from tools/ci_build/github/apple/test_ios_packages.py rename to tools/ci_build/github/apple/test_apple_packages.py index ff42e9615483a..6dc4868dac8a3 100644 --- a/tools/ci_build/github/apple/test_ios_packages.py +++ b/tools/ci_build/github/apple/test_apple_packages.py @@ -19,7 +19,7 @@ REPO_DIR = SCRIPT_PATH.parents[4] -def _test_ios_packages(args): +def _test_apple_packages(args): # check if CocoaPods is installed if shutil.which("pod") is None: if args.fail_if_cocoapods_missing: @@ -58,10 +58,10 @@ def _test_ios_packages(args): os.makedirs(stage_dir) # assemble the test project here - target_proj_path = stage_dir / "ios_package_test" + target_proj_path = stage_dir / "apple_package_test" # copy the test project source files to target_proj_path - test_proj_path = pathlib.Path(REPO_DIR, "onnxruntime/test/platform/ios/ios_package_test") + test_proj_path = pathlib.Path(REPO_DIR, "onnxruntime/test/platform/apple/apple_package_test") shutil.copytree(test_proj_path, target_proj_path) # assemble local pod files here @@ -133,7 +133,7 @@ def _test_ios_packages(args): "xcodebuild", "test", "-workspace", - "./ios_package_test.xcworkspace", + "./apple_package_test.xcworkspace", "-scheme", "ios_package_test", "-destination", @@ -144,6 +144,24 @@ def _test_ios_packages(args): cwd=target_proj_path, ) + if PackageVariant[args.variant] != PackageVariant.Mobile: + subprocess.run( + [ + "xcrun", + "xcodebuild", + "test", + "-workspace", + "./apple_package_test.xcworkspace", + "-scheme", + "macos_package_test", + "-destination", + "platform=macos", + ], + shell=False, + check=True, + cwd=target_proj_path, + ) + def parse_args(): parser = argparse.ArgumentParser( @@ -161,7 +179,7 @@ def parse_args(): "--framework_info_file", type=pathlib.Path, required=True, - help="Path to the framework_info.json file containing additional values for the podspec. " + help="Path to the framework_info.json or xcframework_info.json file containing additional values for the podspec. 
" "This file should be generated by CMake in the build directory.", ) @@ -172,7 +190,7 @@ def parse_args(): parser.add_argument( "--variant", choices=PackageVariant.all_variant_names(), - default=PackageVariant.Test.name, + required=True, help="Pod package variant.", ) @@ -193,7 +211,7 @@ def parse_args(): def main(): args = parse_args() - _test_ios_packages(args) + _test_apple_packages(args) if __name__ == "__main__": diff --git a/tools/ci_build/github/apple/use_ios_pods_with_custom_build.md b/tools/ci_build/github/apple/use_ios_pods_with_custom_build.md index c01f0796db0fb..c8da2eff57c33 100644 --- a/tools/ci_build/github/apple/use_ios_pods_with_custom_build.md +++ b/tools/ci_build/github/apple/use_ios_pods_with_custom_build.md @@ -2,9 +2,9 @@ If you require a custom build of ONNX Runtime, you can create CocoaPods pods with your custom build locally and use them from a Podfile. -**Prerequisite** - The custom build must be able to be done with [build_ios_framework.py](./build_ios_framework.py). +**Prerequisite** - The custom build must be able to be done with [build_apple_framework.py](./build_apple_framework.py). -To do a custom build and create the pods, run [build_and_assemble_ios_pods.py](./build_and_assemble_ios_pods.py). +To do a custom build and create the pods, run [build_and_assemble_apple_pods.py](./build_and_assemble_apple_pods.py). Use the `--help` argument to see more information. ## Example usage @@ -15,7 +15,7 @@ Our custom build will use a custom reduced operator kernel config file: `/path/t Run the script: ```bash -python3 tools/ci_build/github/apple/build_and_assemble_ios_pods.py \ +python3 tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ --staging-dir /path/to/staging/dir \ --include-ops-by-config /path/to/custom.config \ --build-settings-file tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 0eccd71e47f46..67fa78da003a3 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,6 +60,14 @@ parameters: type: string default: '--use_azure' +- name: CudaVersion + displayName: CUDA version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 + resources: repositories: - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step @@ -146,7 +154,13 @@ stages: timeoutInMinutes: 120 pool: 'Onnxruntime-Linux-GPU' variables: - CUDA_VERSION: '11.8' + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + - name: CUDA_VERSION + value: ${{ parameters.CudaVersion }} steps: - template: templates/set-version-number-variables-step.yml - template: templates/get-docker-image-steps.yml @@ -154,7 +168,7 @@ stages: Dockerfile: tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x64/default/gpu DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecuda11centosbuild + Repository: onnxruntimecuda$(CUDA_VERSION_MAJOR)build - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_c_api_package.sh workingDirectory: $(Build.SourcesDirectory) diff --git 
a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml new file mode 100644 index 0000000000000..8a9592282cd46 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -0,0 +1,175 @@ +parameters: + - name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + + - name: UseIncreasedTimeoutForTests + displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. + type: boolean + default: false + + - name: DoCompliance + displayName: Run Compliance Tasks? + type: boolean + default: true + + - name: DoEsrp + displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release + type: boolean + default: true + + - name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. + type: boolean + default: false + + - name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + + - name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + + # these 2 parameters are used for debugging. + - name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + + - name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + + - name: CudaVersion + displayName: CUDA version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +variables: + - name: ReleaseVersionSuffix + value: '' + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + - name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 + - name: win_trt_home + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 + - name: win_cuda_home + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: $(Agent.TempDirectory)\v12.2 +resources: + repositories: + - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step + type: github + endpoint: ort-examples + name: microsoft/onnxruntime-inference-examples + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +stages: +# Set ReleaseVersionSuffix + - stage: Set_ReleaseVersionSuffix + jobs: + - job: Set_Variables + pool: + vmImage: ubuntu-latest + steps: + - checkout: none + - bash: | + # Do not output ##vso[] commands with `set -x` or they may be parsed again and include a trailing quote. 
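+              # Examples: with IsReleaseBuild=True and suffix string "rc", number 2 yields "-rc.2" and number 0 yields "-rc";
+              # any other combination leaves ReleaseVersionSuffix empty.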
+ set +x + if [[ "${{ parameters.IsReleaseBuild }}" = True && "${{ parameters.PreReleaseVersionSuffixString }}" != "none" ]]; then + if [[ "${{ parameters.PreReleaseVersionSuffixNumber }}" -eq 0 ]]; then + echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]-${{ parameters.PreReleaseVersionSuffixString }}" + else + echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]-${{ parameters.PreReleaseVersionSuffixString }}.${{ parameters.PreReleaseVersionSuffixNumber }}" + fi + else + echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]" + fi + name: Set_Release_Version_Suffix + - bash: echo $(ReleaseVersionSuffix) + name: Debug_Release_Version_Suffix + # this is needed for certain artifacts to be published + - stage: Linux_C_API_Packaging_CPU_x64 + dependsOn: [ ] + jobs: + - template: templates/c-api-linux-cpu.yml + parameters: + BaseImage: 'registry.access.redhat.com/ubi8/ubi' + OnnxruntimeArch: 'x64' + OnnxruntimeCFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all' + OnnxruntimeCXXFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all' + OnnxruntimeNodejsBindingArch: 'x64' + PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PackageJava: false + PackageNodeJS: false + # Nuget Packaging + + - template: stages/nuget-linux-cuda-packaging-stage.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + docker_base_image: ${{ variables.docker_base_image }} + linux_trt_version: ${{ variables.linux_trt_version }} + - template: stages/nuget-win-cuda-packaging-stage.yml + parameters: + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} + CudaVersion: ${{ parameters.CudaVersion }} + win_trt_home: ${{ variables.win_trt_home }} + win_cuda_home: ${{ variables.win_cuda_home }} + - template: stages/nuget-combine-cuda-stage.yml + parameters: + DoCompliance: ${{ parameters.DoCompliance }} + DoEsrp: ${{ parameters.DoEsrp }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + # Testing + ## Windows GPU Testing + - template: nuget/templates/test_win.yml + parameters: + AgentPool: 'onnxruntime-Win2022-GPU-T4' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.Gpu' + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' + Skipx86Tests: 'true' + CudaVersion: ${{ parameters.CudaVersion }} + ## Linux GPU Testing + - template: nuget/templates/test_linux.yml + parameters: + AgentPool: Onnxruntime-Linux-GPU + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.Gpu' + SpecificArtifact: ${{ parameters.specificArtifact }} + CudaVersion: ${{ parameters.CudaVersion }} + BuildId: ${{ parameters.BuildId }} + +## Win/Linux GPU Combined Publishing +#- template: templates/publish-nuget.yml diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 9e1fae343c84e..0993a81a02249 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -26,7 +26,14 @@ pr: - 'js/web' - 'onnxruntime/core/providers/js' #### end trigger #### - +parameters: + - name: CudaVersion + displayName: CUDA version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 resources: repositories: - repository: manylinux @@ -37,6 +44,17 @@ 
resources: variables: - template: templates/common-variables.yml + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + + - name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 jobs: - job: Linux_Build @@ -55,15 +73,14 @@ jobs: - checkout: self clean: true submodules: none - - template: templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda11build @@ -163,8 +180,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimecuda11build diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 517c8d638c935..4ca11a4d1565b 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -26,7 +26,14 @@ pr: - 'js/web' - 'onnxruntime/core/providers/js' #### end trigger #### - +parameters: + - name: CudaVersion + displayName: CUDA version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 resources: repositories: - repository: manylinux @@ -34,7 +41,17 @@ resources: endpoint: Microsoft name: pypa/manylinux ref: 5eda9aded5462201e6310105728d33016e637ea7 - +variables: + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + - name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 jobs: - job: Linux_Build timeoutInMinutes: 180 @@ -61,8 +78,8 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg BASEIMAGE=${{ variables.docker_base_image }} + --build-arg TRT_VERSION=${{ variables.linux_trt_version }} --build-arg BUILD_UID=$( id -u ) " Repository: onnxruntimetensorrt86gpubuild @@ -99,7 +116,8 @@ jobs: --build_shared_lib \ --parallel \ --build_wheel \ - --enable_onnx_tests --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \ + --enable_onnx_tests \ + --use_cuda --cuda_home=/usr/local/cuda-${{ parameters.CudaVersion }} --cudnn_home=/usr/local/cuda-${{ parameters.CudaVersion }} \ --enable_pybind --build_java \ --use_tensorrt --tensorrt_home /usr \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 \ diff --git 
a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml index b1d7ede2843c8..18d53654e7c4d 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml @@ -54,7 +54,7 @@ jobs: --use_coreml \ --use_xnnpack \ --ios \ - --ios_sysroot iphonesimulator \ + --apple_sysroot iphonesimulator \ --osx_arch x86_64 \ --apple_deploy_target 12.0 \ --use_xcode \ diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index 64fa29f06553e..1e609b052b8d3 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -7,7 +7,7 @@ parameters: SpecificArtifact: false CustomOpArtifactName: 'onnxruntime-linux-x64' BuildId: '0' - + CudaVersion: '11.8' stages: - stage: NuGet_Test_Linux_${{ parameters.StageSuffix }} dependsOn: @@ -54,9 +54,18 @@ stages: - ${{if contains(parameters.StageSuffix , 'GPU') }}: - template: ../../templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu Context: tools/ci_build/github/linux/docker/ - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + ${{ if eq(parameters.CudaVersion, '12.2') }}: + DockerBuildArgs: " + --build-arg BASEIMAGE=nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04 + --build-arg TRT_VERSION=8.6.1.6-1+cuda12.0 + --build-arg BUILD_UID=$( id -u ) + " + ${{ else }}: + DockerBuildArgs: " + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimepackagestest - bash: | docker run --rm \ diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml index 0b9ded10ddd3e..a15c3061913f8 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml @@ -3,11 +3,12 @@ parameters: NugetPackageName : '' ArtifactSuffix: '' StageSuffix: 'CPU' - # For inference packages, the test data artifact name is drop-nuget and no suffix is required. + # For inference packages, the test data artifact name is drop-extra and no suffix is required. # For training packages, to differentiate the artifact name we add '-training' suffix. This needs to be passed from # the parent pipeline. 
TestDataArtifactSuffix: '' Skipx86Tests: 'false' + CudaVersion: '' stages: - stage: NuGet_Test_Win_${{ parameters.StageSuffix }} @@ -27,6 +28,10 @@ stages: value: 'ON' - name: runCodesignValidationInjection value: false + - name: CUDA_MODULE_LOADING + value: 'LAZY' + - name: GRADLE_OPTS + value: '-Dorg.gradle.daemon=false' steps: - task: UsePythonVersion@0 @@ -39,13 +44,12 @@ stages: displayName: Use Nuget 5.7.0 inputs: versionSpec: 5.7.0 - - - task: BatchScript@1 - displayName: 'setup env' - inputs: - filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\setup_env_gpu.bat' - modifyEnvironment: true - workingFolder: '$(Build.BinariesDirectory)' + - ${{ if ne( parameters.CudaVersion, '') }}: + - template: ../../templates/jobs/download_win_gpu_library.yml + parameters: + DownloadCUDA: true + DownloadTRT: true + CudaVersion: ${{ parameters.CudaVersion }} - task: BatchScript@1 displayName: 'Setup Visual Studio env vars' @@ -60,12 +64,6 @@ stages: artifactName: drop-signed-nuget-${{ parameters.ArtifactSuffix }} targetPath: '$(Build.BinariesDirectory)\nuget-artifact' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - testdata' - inputs: - artifactName: 'drop-nuget${{ parameters.TestDataArtifactSuffix }}' - targetPath: '$(Build.BinariesDirectory)\testdata' - - template: get-nuget-package-version-as-variable.yml parameters: packageFolder: '$(Build.BinariesDirectory)\nuget-artifact' diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml new file mode 100644 index 0000000000000..422fb33eec5de --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml @@ -0,0 +1,22 @@ +trigger: none + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +stages: +- template: templates/py-packaging-training-cuda-stage.yml + parameters: + build_py_parameters: --enable_training --update --build + torch_version: '2.1.0' + opset_version: '15' + cuda_version: '12.2' + cmake_cuda_architectures: 70;75;80;86;90 + docker_file: Dockerfile.manylinux2_28_training_cuda12_2 + agent_pool: Onnxruntime-Linux-GPU + upload_wheel: 'yes' + debug_build: false diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 6fdb255606a19..c86920422b6f0 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -286,14 +286,15 @@ stages: displayName: "Install Python requirements" - script: | - python tools/ci_build/github/apple/build_ios_framework.py \ + python tools/ci_build/github/apple/build_apple_framework.py \ --build_dir "$(Build.BinariesDirectory)/ios_framework" \ --build_dynamic_framework \ tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json displayName: "Build iOS dynamic framework" - script: | - python tools/ci_build/github/apple/test_ios_packages.py \ - --framework_info_file "$(Build.BinariesDirectory)/ios_framework/framework_info.json" \ - --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" + python tools/ci_build/github/apple/test_apple_packages.py \ + --framework_info_file "$(Build.BinariesDirectory)/ios_framework/xcframework_info.json" \ + --c_framework_dir
"$(Build.BinariesDirectory)/ios_framework/framework_out" \ + --variant Mobile displayName: "Test pod with iOS dynamic framework" diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml index aee42d3675087..91179d141498b 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -31,7 +31,7 @@ resources: ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: - - template: stages/py-cuda-packaging-stage.yml + - template: stages/py-nuget-combine-cuda-stage.yml parameters: enable_linux_gpu: ${{ parameters.enable_linux_gpu }} enable_windows_gpu: ${{ parameters.enable_windows_gpu }} diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml new file mode 100644 index 0000000000000..b69e75856c39f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -0,0 +1,228 @@ +parameters: +- name: DoCompliance + type: boolean + default: true + +- name: DoEsrp + type: boolean + default: true + +- name: IsReleaseBuild + type: boolean + default: false + +stages: +######## Nuget ######## +# Win/Linux CUDA Combined packaging +- stage: NuGet_Packaging_GPU + dependsOn: + - Set_ReleaseVersionSuffix + - Windows_Packaging_gpu + - Windows_Packaging_tensorrt + - Linux_C_API_Packaging_CPU_x64 + - Linux_C_API_Packaging_GPU_x64 + - Linux_C_API_Packaging_GPU_TensorRT_x64 + condition: succeeded() + jobs: + - job: + workspace: + clean: all + # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. + # VS2019 has no support for net6 and we need to use msbuild (from the VS install) to do the packing + pool: 'Azure-Pipelines-EO-Windows2022-aiinfra' + variables: + breakCodesignValidationInjection: ${{ parameters.DoEsrp }} + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + + steps: + - checkout: self + submodules: true + # Download the all artifacts + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact from Linux_C_API_Packaging_GPU_x64 Stage' + inputs: + artifactName: 'onnxruntime-win-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact from Linux_C_API_Packaging_GPU_TensorRT_x64 Stage' + inputs: + artifactName: 'onnxruntime-win-x64-tensorrt' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact from Windows_Packaging_gpu Stage' + inputs: + artifactName: 'onnxruntime-linux-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact from Windows_Packaging_tensorrt Stage' + inputs: + artifactName: 'onnxruntime-linux-x64-tensorrt' + targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact - protoc from Windows_Packaging_(cpu|gpu) Stage' + inputs: + artifactName: 'drop-extra' + targetPath: '$(Build.BinariesDirectory)/extra-artifact' + + # Reconstruct the build dir + - task: PowerShell@2 + displayName: 'PS: Extract nuget files gpu' + inputs: + targetType: filePath + filePath: 
$(Build.SourcesDirectory)\tools\ci_build\github\windows\extract_nuget_files_gpu.ps1 + + - script: | + dir + workingDirectory: '$(Build.BinariesDirectory)/nuget-artifact' + displayName: 'List artifacts' + + - script: | + mklink /D /J models C:\local\models + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Create models link' + + - task: NuGetToolInstaller@0 + displayName: Use Nuget 6.2.1 + inputs: + versionSpec: 6.2.1 + + - task: PowerShell@2 + displayName: Install .NET 6 workloads + inputs: + targetType: 'inline' + script: | + dotnet workload install android ios macos + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: PowerShell@2 + displayName: Build .NET 6 targets using dotnet + inputs: + targetType: 'inline' + # we don't specify 'Any CPU' as the platform here because if we do it gets added to the output path + # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\Any CPU\RelWithDebInfo\net6.0-ios\ + # which is inconsistent with the msbuild output path for the pre-.net6 targets + # e.g. csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\monoandroid11.0 + # and makes it harder to do the packing + # + # 'Any CPU' is the default (first 'mixed' platform specified in the csproj) so this should be fine. + script: | + dotnet build .\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj -p:SelectedTargets=Net6 -p:Configuration=RelWithDebInfo -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Restore NuGet Packages and create project.assets.json for pre-.net6 targets' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:restore -p:SelectedTargets=PreNet6 -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build C# for pre-.net6 targets' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + configuration: RelWithDebInfo + platform: 'Any CPU' + msbuildArguments: '-p:SelectedTargets=PreNet6 -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - template: ../templates/win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + DisplayName: 'ESRP - Sign C# dlls' + DoEsrp: ${{ parameters.DoEsrp }} + + - task: MSBuild@1 + displayName: Update projects.assets.json with combined list of all target frameworks + inputs: + solution: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\Microsoft.ML.OnnxRuntime.csproj' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:restore -p:SelectedTargets=All -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: MSBuild@1 + displayName: 'Build Nuget Packages' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' + configuration: RelWithDebInfo + platform: 'Any CPU' + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" 
-p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: BatchScript@1 + displayName: 'Add TensorRT header file to the native nuGet package' + inputs: + filename: $(Build.SourcesDirectory)\tools\ci_build\github\windows\bundle_nuget_with_native_headers.bat + workingFolder: $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.snupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: CopyFiles@2 + displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' + Contents: '*.nupkg' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - template: ../templates/esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)' + DoEsrp: ${{ parameters.DoEsrp }} + + - template: ../templates/validate-package.yml + parameters: + PackageType: 'nuget' + PackagePath: '$(Build.ArtifactStagingDirectory)' + PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' + PlatformsSupported: 'win-x64,linux-x64' + VerifyNugetSigning: false + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-GPU' + targetPath: '$(Build.ArtifactStagingDirectory)' + + + - task: MSBuild@1 + displayName: 'Clean C#' + inputs: + solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + platform: 'Any CPU' + configuration: RelWithDebInfo + msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + + - task: RoslynAnalyzers@2 + displayName: 'Run Roslyn Analyzers' + inputs: + userProvideBuildInfo: msBuildInfo + msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\msbuild.exe" $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln -p:configuration="RelWithDebInfo" -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' + condition: and(succeeded(), eq('${{ parameters.DoCompliance }}', true)) + + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml new file mode 100644 index 0000000000000..140a377ca72a3 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -0,0 +1,161 @@ +parameters: +- name: CudaVersion + type: string + default: '11.8' +- name: docker_base_image + type: 
string +- name: linux_trt_version + type: string + +stages: + # Linux CUDA without TensorRT Packaging +- stage: Linux_C_API_Packaging_GPU_x64 + dependsOn: [] + jobs: + - job: + workspace: + clean: all + timeoutInMinutes: 120 + pool: 'Onnxruntime-Linux-GPU' + variables: + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + - name: CUDA_VERSION + value: ${{ parameters.CudaVersion }} + steps: + - template: ../templates/set-version-number-variables-step.yml + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/x64/default/gpu + DockerBuildArgs: " + --build-arg BUILD_UID=$( id -u ) + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + " + Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}build + + - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_c_api_package.sh + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Build and Test' + + - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml + parameters: + buildConfig: 'Release' + artifactName: 'onnxruntime-linux-x64-cuda-$(OnnxRuntimeVersion)' + artifactNameNoVersionString: 'onnxruntime-linux-x64-cuda' + libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' + + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - template: ../templates/clean-agent-build-directory-step.yml +# Linux CUDA with TensorRT Packaging +- template: ../templates/linux-gpu-tensorrt-packaging-pipeline.yml + parameters: + artifactName: 'onnxruntime-linux-x64-tensorrt-$(OnnxRuntimeVersion)' + artifactNameNoVersionString: 'onnxruntime-linux-x64-tensorrt' + buildJava: false + buildJavaOption: '--build_java' + buildNodejs: false + buildNodejsOption: '--build_nodejs' + CudaVersion: ${{ parameters.CudaVersion }} +# Linux CUDA Combined Testing and Publishing +- stage: Linux_Packaging_combined_GPU + dependsOn: + - Linux_C_API_Packaging_GPU_x64 + - Linux_C_API_Packaging_GPU_TensorRT_x64 + condition: succeeded() + jobs: + - job: + workspace: + clean: all + pool: 'Onnxruntime-Linux-GPU' + + steps: + - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime + submodules: false + - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples + submodules: false + - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux + submodules: false + + - script: | + set -e -x + cd $(Build.SourcesDirectory) + mv manylinux onnxruntime + ls + + - template: ../templates/with-container-registry-steps.yml + parameters: + Steps: + - script: | + tools/ci_build/get_docker_image.py \ + --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ + --context tools/ci_build/github/linux/docker \ + --docker-build-args "--network=host --build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ parameters.linux_trt_version }} --build-arg BUILD_UID=$( id -u )" \ + --container-registry onnxruntimebuildcache \ + --multiple_repos \ + --repository onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build + displayName: "Get onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR 
}}xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda" + workingDirectory: $(Build.SourcesDirectory)/onnxruntime + ContainerRegistry: onnxruntimebuildcache + + - template: ../templates/set-version-number-variables-step.yml + parameters: + versionFileDirectory: '$(Build.SourcesDirectory)/onnxruntime' + workingDirectory: '$(Build.SourcesDirectory)/onnxruntime' + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact - Combined GPU' + inputs: + artifactName: 'onnxruntime-linux-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' + + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact - Combined GPU' + inputs: + artifactName: 'onnxruntime-linux-x64-tensorrt' + targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' + + - task: ShellScript@2 + displayName: 'Shell Script' + inputs: + scriptPath: 'onnxruntime/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh' + args: '-a $(Build.BinariesDirectory)/tgz-artifacts' + workingDirectory: '$(Build.BinariesDirectory)/tgz-artifacts' + + - task: ArchiveFiles@2 + inputs: + rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu' + includeRootFolder: false + archiveType: 'tar' # Options: zip, 7z, tar, wim + tarCompression: 'gz' + archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + replaceExistingArchive: true + + - template: ../templates/validate-package.yml + parameters: + PackageType: 'tarball' + PackagePath: '$(Build.ArtifactStagingDirectory)' + PackageName: 'onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + ScriptPath: '$(Build.SourcesDirectory)/onnxruntime/tools/nuget/validate_package.py' + PlatformsSupported: 'linux-x64' + VerifyNugetSigning: false + workingDirectory: '$(Build.ArtifactStagingDirectory)' + + + - task: CmdLine@2 + displayName: 'Test C API application for GPU package' + inputs: + script: | + docker run --gpus all -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume $(Build.SourcesDirectory):/src_dir \ + --volume $(Build.ArtifactStagingDirectory):/artifact_src -e NIGHTLY_BUILD onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build \ + /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet + workingDirectory: '$(Build.ArtifactStagingDirectory)' + + - task: PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + artifactName: 'onnxruntime-linux-x64-gpu' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml new file mode 100644 index 0000000000000..3fb653c6b4405 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -0,0 +1,147 @@ +parameters: +- name: RunOnnxRuntimeTests + type: boolean + default: true + +- name: UseIncreasedTimeoutForTests + type: boolean + default: false + +- name: DoCompliance + type: boolean + default: true + +- 
name: DoEsrp + type: boolean + default: true + +- name: CudaVersion + type: string + default: '11.8' +- name: win_cuda_home + type: string +- name: win_trt_home + type: string + +stages: +# Windows CUDA without TensorRT Packaging +- template: ../templates/win-ci.yml + parameters: + ort_build_pool_name: 'onnxruntime-Win2022-GPU-T4' + DoCompliance: ${{ parameters.DoCompliance }} + DoEsrp: ${{ parameters.DoEsrp }} + stage_name_suffix: gpu + buildArch: x64 + msbuildPlatform: x64 + packageName: x64-cuda + CudaVersion: ${{ parameters.CudaVersion }} + buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" + runTests: ${{ parameters.RunOnnxRuntimeTests }} + buildJava: false + java_artifact_id: onnxruntime_gpu + PublishProtoc: true +# Windows CUDA with TensorRT Packaging +- template: ../templates/win-ci.yml + parameters: + ort_build_pool_name: 'onnxruntime-Win2022-GPU-T4' + DoCompliance: ${{ parameters.DoCompliance }} + DoEsrp: ${{ parameters.DoEsrp }} + stage_name_suffix: tensorrt + buildArch: x64 + msbuildPlatform: x64 + CudaVersion: ${{ parameters.CudaVersion }} + packageName: x64-tensorrt + buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" + runTests: ${{ parameters.RunOnnxRuntimeTests }} + buildJava: false + java_artifact_id: onnxruntime_gpu + UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} + +# Windows CUDA Combined Testing and Publishing +- stage: Windows_Packaging_combined_GPU + dependsOn: + - Windows_Packaging_gpu + - Windows_Packaging_tensorrt + condition: succeeded() + + jobs: + - job: + workspace: + clean: all + pool: 'onnxruntime-Win2022-GPU-T4' + variables: + CUDA_MODULE_LOADING: 'LAZY' + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + steps: + - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime + - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples + submodules: false + - script: dir $(Build.SourcesDirectory) + - template: ../templates/jobs/download_win_gpu_library.yml + parameters: + DownloadCUDA: true + DownloadTRT: true + CudaVersion: ${{ parameters.CudaVersion }} + + - template: ../templates/set-version-number-variables-step.yml + parameters: + versionFileDirectory: '$(Build.SourcesDirectory)\onnxruntime' + workingDirectory: '$(Build.SourcesDirectory)\onnxruntime' + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact - onnxruntime-win-x64-cuda' + inputs: + artifactName: 'onnxruntime-win-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/zip-artifacts' + + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact - onnxruntime-win-x64-tensorrt' + inputs: + artifactName: 'onnxruntime-win-x64-tensorrt' + targetPath: '$(Build.BinariesDirectory)/zip-artifacts' + + - task: PowerShell@2 + displayName: 'PowerShell Script' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\onnxruntime\tools\ci_build\github\windows\extract_zip_files_gpu.ps1 + + - script: | + dir + workingDirectory: '$(Build.BinariesDirectory)/zip-artifacts' + displayName: 'List artifacts' + + - task: BatchScript@1 + displayName: 'Bundle CUDA/TRT EP binaries' + inputs: + filename:
$(Build.SourcesDirectory)\onnxruntime\tools\ci_build\github\windows\bundle_dlls_gpu.bat + workingFolder: $(Build.BinariesDirectory)\zip-artifacts + + - task: CopyFiles@2 + displayName: 'Copy zip file to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\zip-artifacts' + Contents: 'onnxruntime-win-x64-gpu-*.zip' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - template: ../templates/validate-package.yml + parameters: + PackageType: 'zip' + PackagePath: '$(Build.ArtifactStagingDirectory)' + PackageName: 'onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip' + ScriptPath: '$(Build.SourcesDirectory)\onnxruntime\tools\nuget\validate_package.py' + PlatformsSupported: 'win-x64' + VerifyNugetSigning: false + workingDirectory: '$(Build.ArtifactStagingDirectory)' + + - task: BatchScript@1 + displayName: 'Test C API application for GPU package' + inputs: + filename: $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet\run_capi_application.bat + arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet + workingFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline Combined GPU Package Artifact' + inputs: + artifactName: 'onnxruntime-win-x64-gpu' + targetPath: '$(Build.ArtifactStagingDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 4ce39ecc35bfb..87fd4de7d3127 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -117,32 +117,32 @@ stages: - script: | set -e -x - python3 tools/ci_build/github/apple/build_ios_framework.py \ - --build_dir "$(Build.BinariesDirectory)/ios_framework" \ + python3 tools/ci_build/github/apple/build_apple_framework.py \ + --build_dir "$(Build.BinariesDirectory)/apple_framework" \ --path_to_protoc_exe $(Build.BinariesDirectory)/protobuf_install/bin/protoc \ - tools/ci_build/github/apple/default_full_ios_framework_build_settings.json + tools/ci_build/github/apple/default_full_apple_framework_build_settings.json mkdir $(Build.BinariesDirectory)/artifacts - mkdir -p $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) - cp -R $(Build.BinariesDirectory)/ios_framework/framework_out/onnxruntime.xcframework \ - $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) + mkdir -p $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-apple-xcframework-$(OnnxRuntimeVersion) + cp -R $(Build.BinariesDirectory)/apple_framework/framework_out/onnxruntime.xcframework \ + $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-apple-xcframework-$(OnnxRuntimeVersion) pushd $(Build.BinariesDirectory)/artifacts_staging zip -vr $(Build.BinariesDirectory)/artifacts/onnxruntime_xcframework.zip \ - onnxruntime-ios-xcframework-$(OnnxRuntimeVersion) + onnxruntime-apple-xcframework-$(OnnxRuntimeVersion) popd - displayName: "Build iOS xcframework" + displayName: "Build Apple xcframework" - script: | - python3 tools/ci_build/github/apple/test_ios_packages.py \ + python3 tools/ci_build/github/apple/test_apple_packages.py \ --fail_if_cocoapods_missing \ - --framework_info_file 
"$(Build.BinariesDirectory)/ios_framework/framework_info.json" \ - --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \ + --framework_info_file "$(Build.BinariesDirectory)/apple_framework/xcframework_info.json" \ + --c_framework_dir "$(Build.BinariesDirectory)/apple_framework/framework_out" \ --variant Full - displayName: "Test iOS framework" + displayName: "Test Apple framework" - task: PublishBuildArtifacts@1 inputs: pathtoPublish: '$(Build.BinariesDirectory)/artifacts' - artifactName: 'onnxruntime-ios-full-xcframework' + artifactName: 'onnxruntime-apple-full-xcframework' - template: component-governance-component-detection-steps.yml parameters: @@ -304,9 +304,7 @@ stages: - job: workspace: clean: all - # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. - # VS2019 has no support for net6 and we need to use msbuild (from the VS install) to do the packing - pool: 'Azure-Pipelines-EO-Windows2022-aiinfra' + pool: 'onnxruntime-Win-CPU-2022' variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} breakCodesignValidationInjection: ${{ parameters.DoEsrp }} @@ -315,66 +313,86 @@ stages: steps: - checkout: self submodules: true - - task: DownloadPipelineArtifact@0 - displayName: 'Download win-x64 Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-win-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@0 - displayName: 'Download win-x86 Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-win-x86' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Win x64' + ArtifactName: 'onnxruntime-win-x64' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@0 - displayName: 'Download win-arm64 Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-win-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download win-x86 Pipeline Artifact' + ArtifactName: 'onnxruntime-win-x86' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@0 - displayName: 'Download win-arm Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-win-arm' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download win-arm64 Pipeline Artifact' + ArtifactName: 'onnxruntime-win-arm64' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@0 - displayName: 'Download osx-x64 Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-osx' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download win-arm Pipeline Artifact' + ArtifactName: 'onnxruntime-win-arm' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@0 - displayName: 'Download linux-x64 Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-linux-x64' - targetPath: 
'$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download osx-x64 Pipeline Artifact' + ArtifactName: 'onnxruntime-osx' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet' - inputs: - artifactName: 'onnxruntime-linux-aarch64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download linux-x64 Pipeline Artifact' + ArtifactName: 'onnxruntime-linux-x64' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@2 - displayName: 'Download iOS Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-ios-full-xcframework' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download linux-aarch64 Pipeline Artifact' + ArtifactName: 'onnxruntime-linux-aarch64' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@2 - displayName: 'Download android-full-aar Pipeline Artifact' - inputs: - artifactName: 'onnxruntime-android-full-aar' - patterns: '**/*.aar' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download iOS Pipeline Artifact' + ArtifactName: 'onnxruntime-ios-full-xcframework' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - - task: DownloadPipelineArtifact@0 - displayName: 'Download drop-extra Pipeline Artifact' - inputs: - artifactName: 'drop-extra' - targetPath: '$(Build.BinariesDirectory)/extra-artifact' + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Android-full-aar Pipeline Artifact' + ArtifactName: 'onnxruntime-android-full-aar' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download drop-extra Pipeline Artifact' + ArtifactName: 'drop-extra' + TargetPath: '$(Build.BinariesDirectory)/extra-artifact' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} - script: | dir diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index f2deb2041e06e..7484e0285fd2c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.118 + version: 1.0.120 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.118 + version: 1.0.120 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO 
accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index ff7f0957e94ba..b7ae9ffa3c219 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -13,7 +13,6 @@ parameters: - 12.2 steps: - - ${{ if eq(parameters.DownloadCUDA, true) }}: - powershell: | azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }} $(Agent.TempDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml index 85562d7758ab2..7693e8f2cd21c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml @@ -23,12 +23,33 @@ parameters: type: string default: '' +- name: CudaVersion + displayName: CUDA version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 + + + # We only have CUDA/TRT on x64. We do not have a build for CUDA/TRT for ARM64. # Therefore this file does not have an `OnnxruntimeNodejsBindingArch` parameter stages: - stage: Linux_C_API_Packaging_GPU_TensorRT_x64 dependsOn: [] + variables: + - name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 jobs: - job: dependsOn: [] @@ -37,7 +58,13 @@ stages: timeoutInMinutes: 180 pool: 'Onnxruntime-Linux-GPU' variables: - CUDA_VERSION: '11.8' + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + - name: CUDA_VERSION + value: ${{ parameters.CudaVersion }} steps: - checkout: self clean: true @@ -48,11 +75,11 @@ stages: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg BASEIMAGE=${{ variables.docker_base_image }} + --build-arg TRT_VERSION=${{ variables.linux_trt_version }} --build-arg BUILD_UID=$( id -u ) " - Repository: onnxruntimecuda118xtrt86build + Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build - template: set-version-number-variables-step.yml - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml index 6b5fba7785fe0..00ba5ea4a475a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml @@ -168,7 +168,7 @@ stages: inputs: filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' workingDirectory: '$(Build.BinariesDirectory)' - arguments: -cpu_arch x64 -install_prefix 
$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\installed -build_config $(BuildConfig) + arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\$(BuildConfig)\installed -build_config $(BuildConfig) - task: PythonScript@0 displayName: 'Generate cmake config' diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 33f956f931f18..47cd72f412c67 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -126,7 +126,7 @@ stages: BuildStep: - script: | set -e -x - python $(Build.SourcesDirectory)/tools/ci_build/github/apple/build_and_assemble_ios_pods.py \ + python $(Build.SourcesDirectory)/tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ --build-dir "$(Build.BinariesDirectory)/ios_framework_full" \ --staging-dir "$(Build.BinariesDirectory)/staging" \ --variant Full \ @@ -134,7 +134,7 @@ stages: -b="--path_to_protoc_exe" -b "$(Build.BinariesDirectory)/installed/bin/protoc" # Mobile build: - # python $(Build.SourcesDirectory)/tools/ci_build/github/apple/build_and_assemble_ios_pods.py \ + # python $(Build.SourcesDirectory)/tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ # --build_dir $(Build.BinariesDirectory)/ios_framework_mobile \ # --staging-dir "$(Build.BinariesDirectory)/staging" \ # --include_ops_by_config $(Build.SourcesDirectory)/tools/ci_build/github/android/mobile_package.required_operators.config \ diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 81f17a26b16a6..1a7915172e211 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -29,7 +29,7 @@ stages: objcPodName: onnxruntime-mobile-objc ${{ if eq(parameters.packageVariant, 'Full') }}: - buildSettingsFile: "tools/ci_build/github/apple/default_full_ios_framework_build_settings.json" + buildSettingsFile: "tools/ci_build/github/apple/default_full_apple_framework_build_settings.json" cPodName: onnxruntime-c objcPodName: onnxruntime-objc @@ -38,7 +38,7 @@ stages: cPodName: onnxruntime-training-c objcPodName: onnxruntime-training-objc - timeoutInMinutes: 120 + timeoutInMinutes: 180 steps: - script: | @@ -84,8 +84,8 @@ stages: # create and test mobile pods - script: | - python tools/ci_build/github/apple/build_and_assemble_ios_pods.py \ - --build-dir "$(Build.BinariesDirectory)/ios_framework" \ + python tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ + --build-dir "$(Build.BinariesDirectory)/apple_framework" \ --staging-dir "$(Build.BinariesDirectory)/staging" \ --pod-version "$(ortPodVersion)" \ --test \ @@ -93,13 +93,13 @@ stages: --build-settings-file "${{ variables.buildSettingsFile }}" \ ${{ variables.optionalIncludeOpsByConfigOption }} \ -b="--path_to_protoc_exe=$(Build.BinariesDirectory)/protobuf_install/bin/protoc" - displayName: "Build iOS framework and assemble pod package files" + displayName: "Build macOS/iOS framework and assemble pod package files" - script: | - python tools/ci_build/github/apple/test_ios_packages.py \ + python tools/ci_build/github/apple/test_apple_packages.py \ --fail_if_cocoapods_missing \ - --framework_info_file 
"$(Build.BinariesDirectory)/ios_framework/framework_info.json" \ - --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \ + --framework_info_file "$(Build.BinariesDirectory)/apple_framework/xcframework_info.json" \ + --c_framework_dir "$(Build.BinariesDirectory)/apple_framework/framework_out" \ --variant ${{ parameters.packageVariant }} \ --test_project_stage_dir "$(Build.BinariesDirectory)/app_center_test" \ --prepare_test_project_only @@ -109,7 +109,7 @@ stages: inputs: actions: 'build-for-testing' configuration: 'Debug' - xcWorkspacePath: '$(Build.BinariesDirectory)/app_center_test/ios_package_test/ios_package_test.xcworkspace' + xcWorkspacePath: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/apple_package_test.xcworkspace' sdk: 'iphoneos' scheme: 'ios_package_test' xcodeVersion: 'specifyPath' @@ -118,8 +118,8 @@ stages: signingIdentity: '$(APPLE_CERTIFICATE_SIGNING_IDENTITY)' provisioningProfileName: 'temporary *' # temporary name, change it back to the original below later #provisioningProfileName: 'iOS Team Provisioning Profile' - args: '-derivedDataPath $(Build.BinariesDirectory)/app_center_test/ios_package_test/DerivedData' - workingDirectory: '$(Build.BinariesDirectory)/app_center_test/ios_package_test/' + args: '-derivedDataPath $(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData' + workingDirectory: '$(Build.BinariesDirectory)/app_center_test/apple_package_test/' useXcpretty: false # xcpretty can hide useful error output so we will disable it displayName: 'Build App Center iPhone arm64 tests' @@ -130,7 +130,7 @@ stages: --devices $(app_center_test_devices) \ --test-series "master" \ --locale "en_US" \ - --build-dir $(Build.BinariesDirectory)/app_center_test/ios_package_test/DerivedData/Build/Products/Debug-iphoneos \ + --build-dir $(Build.BinariesDirectory)/app_center_test/apple_package_test/DerivedData/Build/Products/Debug-iphoneos \ --token $(app_center_api_token) displayName: "Run E2E tests on App Center" @@ -139,7 +139,7 @@ stages: for POD_NAME in "${{ variables.cPodName}}" "${{ variables.objcPodName }}"; do - ./tools/ci_build/github/apple/assemble_ios_packaging_artifacts.sh \ + ./tools/ci_build/github/apple/assemble_apple_packaging_artifacts.sh \ "$(Build.BinariesDirectory)/staging" \ "$(Build.ArtifactStagingDirectory)" \ "${POD_NAME}" \ diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 8d28b4ce580b4..a31b2fedbf217 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -11,6 +11,7 @@ parameters: - name: EnvSetupScript type: string + default: '' - name: buildArch type: string @@ -63,11 +64,24 @@ parameters: type: boolean default: false +- name: PublishProtoc + type: boolean + default: false + +- name: CudaVersion + type: string + default: '11.8' + values: + - 11.8 + - 12.2 + stages: - stage: Windows_Packaging_${{ parameters.stage_name_suffix }} dependsOn: [] variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' + CUDA_MODULE_LOADING: 'LAZY' jobs: - job: workspace: @@ -102,12 +116,26 @@ stages: condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: versionSpec: '18.x' + - ${{ if ne(parameters.EnvSetupScript, '') }}: + - template: jobs/set-winenv.yml + parameters: + EnvSetupScript: ${{ parameters.EnvSetupScript }} + ${{ if contains(parameters.buildparameter, 'use_cuda') }}: + 
DownloadCUDA: true - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.EnvSetupScript }} - ${{ if contains(parameters.buildparameter, 'use_cuda') }}: - DownloadCUDA: true + - ${{ if eq(parameters.EnvSetupScript, '') }}: + - template: jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if contains(parameters.buildparameter, 'use_cuda') }}: + DownloadCUDA: true + ${{ if contains(parameters.buildparameter, 'use_tensorrt') }}: + DownloadCUDA: true + DownloadTRT: true + - powershell: | + Write-Host "##vso[task.prependpath]C:\Program Files (x86)\dotnet" + displayName: 'Append dotnet x86 Directory to PATH' + condition: and(succeeded(), eq('${{ parameters.buildArch}}', 'x86')) - template: download-deps.yml @@ -180,7 +208,8 @@ stages: #Upload protoc.exe, which will be used in nuget build for generating C# files - task: PublishPipelineArtifact@1 - condition: and(succeeded(), eq('${{ parameters.packageName}}', 'x64')) + displayName: Publish protoc as drop-extra + condition: and(succeeded(), or(eq('${{ parameters.packageName}}', 'x64'), eq('${{ parameters.PublishProtoc}}', true))) inputs: targetPath: '$(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe' artifactName: 'drop-extra${{ parameters.artifact_name_suffix }}' @@ -194,13 +223,6 @@ stages: Contents: 'custom_op_library.dll' TargetFolder: '$(Build.ArtifactStagingDirectory)/testdata' - #To be used in test_win.yml - - task: PublishPipelineArtifact@1 - condition: and(succeeded(), eq('${{ parameters.packageName}}', 'x64')) - inputs: - targetPath: '$(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe' - artifactName: 'drop-nuget${{ parameters.artifact_name_suffix }}' - - task: CmdLine@2 condition: and(succeeded(), eq('${{ parameters.buildJava}}', true)) displayName: 'Add symbols and notices to Java' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 65fcf98634456..b7ec3305003d7 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -95,6 +95,18 @@ jobs: targetFolder: $(Build.SourcesDirectory)\js\web\lib\wasm\binding flattenFolders: true displayName: 'Binplace js files' + - script: | + npm i -g puppeteer + workingDirectory: '$(Build.SourcesDirectory)' + displayName: 'Use puppeteer to prepare Chrome for tests' + - script: | + FOR /F "tokens=* USEBACKQ" %%F IN (`where /r %HOMEDRIVE%%HOMEPATH%\.cache\puppeteer chrome.exe`) DO ( + SET var=%%F + ECHO found chrome.exe: %%F + ) + ECHO ##vso[task.setvariable variable=CHROME_BIN;]%var% + workingDirectory: '$(Build.SourcesDirectory)' + displayName: 'Set CHROME_BIN' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' @@ -156,85 +168,37 @@ jobs: workingDirectory: $(Build.BinariesDirectory) errorActionPreference: stop displayName: 'Pack NPM packages' - - task: PowerShell@2 - inputs: - targetType: 'inline' - script: Get-WmiObject Win32_Process -Filter "name = 'msedge.exe'" | Select-Object CommandLine | Format-List - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'Dump active Edge processes (before tests 0)' - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack + npm test -- -e=chrome -b=webgl,wasm,xnnpack workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - 
npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) + npm test -- -e=chrome -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - - task: PowerShell@2 - inputs: - targetType: 'inline' - script: Get-WmiObject Win32_Process -Filter "name = 'msedge.exe'" | Select-Object CommandLine | Format-List - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'Dump active Edge processes (before tests 1)' - script: | - npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - # temporarily allow this test to fail, so that people are not blocked. - # investigation is ongoing for the root cause of the random failure (Edge crash). - # TODO: remove this line once the root cause is found and fixed. - continueOnError: true - - task: PowerShell@2 - inputs: - targetType: 'inline' - script: Get-WmiObject Win32_Process -Filter "name = 'msedge.exe'" | Select-Object CommandLine | Format-List - workingDirectory: '$(Build.SourcesDirectory)\js\web' - condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - displayName: 'Dump active Edge processes (before tests 2)' - script: | - npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - - task: PowerShell@2 - inputs: - targetType: 'inline' - script: Get-WmiObject Win32_Process -Filter "name = 'msedge.exe'" | Select-Object CommandLine | Format-List - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'Dump active Edge processes (before tests 3)' - script: | - npm test -- --webgl-texture-pack-mode -b=webgl -e=edge + npm test -- --webgl-texture-pack-mode -b=webgl -e=chrome workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebGL: packed mode' - - task: PowerShell@2 - inputs: - targetType: 'inline' - script: Get-WmiObject Win32_Process -Filter "name = 'msedge.exe'" | Select-Object CommandLine | Format-List - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'Dump active Edge processes (before tests 4)' - script: | - npm test -- --wasm-enable-proxy -b=wasm -e=edge + npm test -- --wasm-enable-proxy -b=wasm -e=chrome workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebAssembly: proxy' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) - - task: PowerShell@2 - inputs: - targetType: 'inline' - script: Get-WmiObject Win32_Process -Filter "name = 'msedge.exe'" | Select-Object CommandLine | Format-List - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'Dump active Edge processes (before E2E tests)' - - task: PowerShell@2 - inputs: - targetType: 'inline' - script: dir -r $(Build.SourcesDirectory)\build\js\e2e - workingDirectory: '$(Build.SourcesDirectory)\js\web' - errorActionPreference: continue 
- displayName: 'Dump E2E test folder (before E2E tests)' - script: | - npm run test:e2e -- --browser=Edge_default + npm run test:e2e -- --browser=Chrome_default workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'E2E package consuming test' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) diff --git a/tools/ci_build/github/js/react_native_e2e_full_ios_framework_build_settings.json b/tools/ci_build/github/js/react_native_e2e_full_ios_framework_build_settings.json index d15326de41099..78de7edb5ec29 100644 --- a/tools/ci_build/github/js/react_native_e2e_full_ios_framework_build_settings.json +++ b/tools/ci_build/github/js/react_native_e2e_full_ios_framework_build_settings.json @@ -4,13 +4,17 @@ "x86_64" ] }, - "build_params": [ - "--ios", - "--parallel", - "--use_xcode", - "--build_apple_framework", - "--use_coreml", - "--skip_tests", - "--apple_deploy_target=12.0" - ] + "build_params": { + "base": [ + "--parallel", + "--use_xcode", + "--build_apple_framework", + "--use_coreml", + "--skip_tests" + ], + "iphonesimulator": [ + "--ios", + "--apple_deploy_target=12.0" + ] + } } diff --git a/tools/ci_build/github/js/react_native_e2e_mobile_ios_framework_build_settings.json b/tools/ci_build/github/js/react_native_e2e_mobile_ios_framework_build_settings.json index e733885399f72..3d80231393cc6 100644 --- a/tools/ci_build/github/js/react_native_e2e_mobile_ios_framework_build_settings.json +++ b/tools/ci_build/github/js/react_native_e2e_mobile_ios_framework_build_settings.json @@ -4,18 +4,22 @@ "x86_64" ] }, - "build_params": [ - "--ios", - "--parallel", - "--use_xcode", - "--build_apple_framework", - "--minimal_build=extended", - "--disable_rtti", - "--disable_ml_ops", - "--disable_exceptions", - "--enable_reduced_operator_type_support", - "--use_coreml", - "--skip_tests", - "--apple_deploy_target=12.0" - ] + "build_params": { + "base": [ + "--parallel", + "--use_xcode", + "--build_apple_framework", + "--minimal_build=extended", + "--disable_rtti", + "--disable_ml_ops", + "--disable_exceptions", + "--enable_reduced_operator_type_support", + "--use_coreml", + "--skip_tests" + ], + "iphonesimulator": [ + "--ios", + "--apple_deploy_target=12.0" + ] + } } diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh index 5cd1c8c243050..2ec8bc82ae048 100755 --- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh +++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -4,7 +4,7 @@ export CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protect export CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" docker run --gpus all -e CFLAGS -e CXXFLAGS -e NVIDIA_VISIBLE_DEVICES=all --rm --volume \ $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ ---volume /data/models:/build/models:ro --volume /data/onnx:/data/onnx:ro -e NIGHTLY_BUILD onnxruntimecuda11centosbuild \ +--volume /data/models:/build/models:ro --volume /data/onnx:/data/onnx:ro -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}build \ /usr/bin/python3.9 /onnxruntime_src/tools/ci_build/build.py --build_java --build_nodejs --build_dir /build --config Release \ --skip_submodule_sync --parallel --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION \ --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION \ diff --git 
a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh index 18a32e3599391..5bf6a69170074 100755 --- a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh +++ b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh @@ -4,6 +4,6 @@ export CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protect export CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" mkdir -p $HOME/.onnx docker run --gpus all -e CFLAGS -e CXXFLAGS -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ ---volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda118xtrt86build \ +--volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release \ --skip_submodule_sync --parallel --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index d4aa9b269095f..8f265b208cd47 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -8,6 +8,7 @@ ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG DEVTOOLSET_ROOTPATH=/usr ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64 ARG PREPEND_PATH=/usr/local/cuda/binet +ARG TRT_VERSION=8.6.1.6-1.cuda11.8 #Build manylinux docker image begin FROM $BASEIMAGE AS runtime_base diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 new file mode 100644 index 0000000000000..a36f60b87768d --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 @@ -0,0 +1,180 @@ +ARG BASEIMAGE=nvidia/cuda:12.2.2-cudnn8-devel-ubi8 +ARG POLICY=manylinux2014 +ARG PLATFORM=x86_64 +ARG DEVTOOLSET_ROOTPATH= +ARG LD_LIBRARY_PATH_ARG= +ARG PREPEND_PATH= + +#We need both CUDA and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: +#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and +#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. +#So we use CUDA as the base image then add manylinux on top of it. 
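The stages below implement that layering: the CUDA base supplies the toolkit, while the manylinux stages build git, the CPython interpreters and the rest of the toolchain, which are then copied into the final image. As a rough sketch of how the build arguments declared here fit together, the Python below composes an equivalent local docker build; build_training_image and the image tag are made-up names, the values are simply the ARG defaults from this file, and a build outside the CI would additionally need the manylinux build_scripts staged into the Docker context (the pipelines obtain them from the pypa/manylinux repository resource).

import shlex
import subprocess

# Hypothetical helper: composes a local "docker build" for the new training image.
# The CI itself drives this via templates/py-packaging-training-cuda-stage.yml.
def build_training_image(tag: str = "onnxruntime-training-cuda12:dev") -> None:
    build_args = {
        "BASEIMAGE": "nvidia/cuda:12.2.2-cudnn8-devel-ubi8",
        "PYTHON_VERSION": "3.9",
        "TORCH_VERSION": "2.1.0",
        "OPSET_VERSION": "15",
        "BUILD_UID": "1001",
    }
    cmd = [
        "docker", "build",
        "-f", "tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2",
        "-t", tag,
    ]
    for name, value in build_args.items():
        cmd += ["--build-arg", f"{name}={value}"]
    cmd.append("tools/ci_build/github/linux/docker")  # build context used by the pipelines
    print(shlex.join(cmd))  # inspect first; running assumes Docker and staged build_scripts
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    build_training_image()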
+ +#Build manylinux2014 docker image begin +FROM $BASEIMAGE AS runtime_base +ARG POLICY +ARG PLATFORM +ARG DEVTOOLSET_ROOTPATH +ARG LD_LIBRARY_PATH_ARG +ARG PREPEND_PATH +LABEL maintainer="The ManyLinux project" + +ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} +ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 +ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} +ENV PATH=${PREPEND_PATH}${PATH} +ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig + +# first copy the fixup mirrors script, keep the script around +COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors + +# setup entrypoint, this will wrap commands with `linux32` with i686 images +COPY build_scripts/install-entrypoint.sh \ + build_scripts/build_utils.sh \ + /build_scripts/ + +RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts +COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint +ENTRYPOINT ["manylinux-entrypoint"] + +COPY build_scripts/install-runtime-packages.sh \ + build_scripts/build_utils.sh \ + /build_scripts/ +RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ + +COPY build_scripts/build_utils.sh /build_scripts/ + +COPY build_scripts/install-autoconf.sh /build_scripts/ +RUN export AUTOCONF_ROOT=autoconf-2.71 && \ + export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ + export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ + manylinux-entrypoint /build_scripts/install-autoconf.sh + +COPY build_scripts/install-automake.sh /build_scripts/ +RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ + export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ + export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ + manylinux-entrypoint /build_scripts/install-automake.sh + +COPY build_scripts/install-libtool.sh /build_scripts/ +RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ + export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ + export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ + manylinux-entrypoint /build_scripts/install-libtool.sh + +COPY build_scripts/install-libxcrypt.sh /build_scripts/ +RUN export LIBXCRYPT_VERSION=4.4.28 && \ + export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ + export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ + export PERL_ROOT=perl-5.34.0 && \ + export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ + export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ + manylinux-entrypoint /build_scripts/install-libxcrypt.sh + +FROM runtime_base AS build_base +COPY build_scripts/install-build-packages.sh /build_scripts/ +RUN manylinux-entrypoint /build_scripts/install-build-packages.sh + + +FROM build_base AS build_git +COPY build_scripts/build-git.sh /build_scripts/ +RUN export GIT_ROOT=git-2.36.2 && \ + export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ + export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ + manylinux-entrypoint /build_scripts/build-git.sh + + +FROM build_base AS build_cpython +COPY build_scripts/build-sqlite3.sh /build_scripts/ +RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ + export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ + export 
SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ + manylinux-entrypoint /build_scripts/build-sqlite3.sh + +COPY build_scripts/build-openssl.sh /build_scripts/ +RUN export OPENSSL_ROOT=openssl-1.1.1q && \ + export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ + export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ + manylinux-entrypoint /build_scripts/build-openssl.sh + +COPY build_scripts/build-cpython.sh /build_scripts/ + + +FROM build_cpython AS build_cpython38 +COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 + + +FROM build_cpython AS build_cpython39 +COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 + + +FROM build_cpython AS build_cpython310 +COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 + +FROM build_cpython AS build_cpython311 +COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 + +FROM build_cpython AS all_python +COPY build_scripts/install-pypy.sh \ + build_scripts/pypy.sha256 \ + build_scripts/finalize-python.sh \ + /build_scripts/ +RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 +RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 +COPY --from=build_cpython38 /opt/_internal /opt/_internal/ +COPY --from=build_cpython39 /opt/_internal /opt/_internal/ +COPY --from=build_cpython310 /opt/_internal /opt/_internal/ +COPY --from=build_cpython311 /opt/_internal /opt/_internal/ +RUN manylinux-entrypoint /build_scripts/finalize-python.sh + + +FROM runtime_base +COPY --from=build_git /manylinux-rootfs / +COPY --from=build_cpython /manylinux-rootfs / +COPY --from=all_python /opt/_internal /opt/_internal/ +COPY build_scripts/finalize.sh \ + build_scripts/python-tag-abi-tag.py \ + build_scripts/requirements3.8.txt \ + build_scripts/requirements3.9.txt \ + build_scripts/requirements3.10.txt \ + build_scripts/requirements3.11.txt \ + build_scripts/requirements-base-tools.txt \ + /build_scripts/ +COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ +RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts + +ENV SSL_CERT_FILE=/opt/_internal/certs.pem + +CMD ["/bin/bash"] + +#Build manylinux2014 docker image end +ARG PYTHON_VERSION=3.9 +ARG TORCH_VERSION=2.1.0 +ARG OPSET_VERSION=15 +ARG INSTALL_DEPS_EXTRA_ARGS + +#Add our own dependencies +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && \ + /tmp/scripts/manylinux/install_centos.sh && \ + /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS && \ + /tmp/scripts/install_rust.sh + +ENV PATH="/root/.cargo/bin/:$PATH" + +RUN /tmp/scripts/install_ninja.sh && \ + /tmp/scripts/install_python_deps.sh -d gpu -v 12.2 -p $PYTHON_VERSION -h $TORCH_VERSION $INSTALL_DEPS_EXTRA_ARGS && \ + rm -rf /tmp/scripts + +ARG BUILD_UID=1001 +ARG BUILD_USER=onnxruntimedev +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER +ENV PATH /usr/local/dotnet:$PATH +ENV ORTMODULE_ONNX_OPSET_VERSION=$OPSET_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 index bbdb411b790a0..8ef8e05b8ac77 100644 --- 
a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -5,8 +5,10 @@ # Dockerfile to Test ONNX Runtime on UBI8 with CUDA 11.8 and TensorRT 8.6 # Build base image with required system packages -FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 AS base - +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +ARG TRT_VERSION=8.6.1.6-1.cuda11.8 +FROM $BASEIMAGE AS base +ARG TRT_VERSION ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} RUN dnf install -y bash wget &&\ @@ -26,8 +28,7 @@ RUN pip3 install setuptools>=68.2.2 # Install TensorRT RUN dnf install -y libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 -RUN v="8.6.1.6-1+cuda11.8" &&\ - dnf downgrade -y libnvinfer8-${v} libnvinfer8-${v} libnvonnxparsers8-${v} libnvparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-lean8-${v} libnvinfer-vc-plugin8-${v} libnvinfer-dispatch8-${v} &&\ +RUN dnf downgrade -y libnvinfer8-${TRT_VERSION} libnvinfer8-${TRT_VERSION} libnvonnxparsers8-${TRT_VERSION} libnvparsers8-${TRT_VERSION} libnvinfer-plugin8-${TRT_VERSION} libnvinfer-lean8-${TRT_VERSION} libnvinfer-vc-plugin8-${TRT_VERSION} libnvinfer-dispatch8-${TRT_VERSION} &&\ dnf install -y dnf-plugin-versionlock &&\ dnf versionlock libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 RUN dnf clean dbcache diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu similarity index 50% rename from tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 rename to tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index 83a974469234f..9b9dc9ecae822 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -5,11 +5,16 @@ # Dockerfile to run ONNXRuntime with TensorRT integration # Build base image with required system packages -FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base - +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 +ARG TRT_VERSION=8.6.1.6-1+cuda11.8 +ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64:/usr/local/cuda/lib64 +FROM $BASEIMAGE AS base +ARG TRT_VERSION ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} + RUN apt-get update &&\ apt-get install -y git bash wget @@ -24,12 +29,11 @@ RUN apt-get install -y --no-install-recommends \ RUN pip install --upgrade pip # Install TensorRT -RUN v="8.6.1.6-1+cuda11.8" &&\ - apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ - apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ - libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v} 
libnvinfer-dispatch-dev=${v}\ - python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v} + apt-get install -y libnvinfer8=${TRT_VERSION} libnvonnxparsers8=${TRT_VERSION} libnvparsers8=${TRT_VERSION} libnvinfer-plugin8=${TRT_VERSION} libnvinfer-lean8=${TRT_VERSION} libnvinfer-vc-plugin8=${TRT_VERSION} libnvinfer-dispatch8=${TRT_VERSION}\ + libnvinfer-headers-dev=${TRT_VERSION} libnvinfer-headers-plugin-dev=${TRT_VERSION} libnvinfer-dev=${TRT_VERSION} libnvonnxparsers-dev=${TRT_VERSION} libnvparsers-dev=${TRT_VERSION} libnvinfer-plugin-dev=${TRT_VERSION} libnvinfer-lean-dev=${TRT_VERSION} libnvinfer-vc-plugin-dev=${TRT_VERSION} libnvinfer-dispatch-dev=${TRT_VERSION}\ + python3-libnvinfer=${TRT_VERSION} libnvinfer-samples=${TRT_VERSION} tensorrt-dev=${TRT_VERSION} tensorrt-libs=${TRT_VERSION} ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile index 318791072f46d..b1ff40e8effef 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile @@ -2,8 +2,8 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - +ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +FROM $BASEIMAGE ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt new file mode 100644 index 0000000000000..152a17db90366 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt @@ -0,0 +1,7 @@ +--pre +-f https://download.pytorch.org/whl/torch_stable.html +torch==2.1.0+cu121 +torchvision==0.16.0+cu121 +torchtext==0.16.0 +packaging==23.1 +setuptools>=68.2.2 diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp new file mode 100644 index 0000000000000..a89ac561f8860 --- /dev/null +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "lib/Api/pch/pch.h" + +#include "HardwareCoreEnumerator.h" + +namespace WINMLP { + +struct LogicalProcessorInformation { + std::unique_ptr Buffer; + size_t Length; +}; + +struct CoreCounter { + uint32_t PhysicalCores = 0; + uint32_t SocDieCores = 0; +}; + +static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { + DWORD length = 0; + DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length); + + assert(rc == FALSE); + + auto processorInformationBytes = std::make_unique(length); + + rc = GetLogicalProcessorInformationEx( + relationship, reinterpret_cast(processorInformationBytes.get()), &length + ); + + assert(rc == TRUE); + + return {std::move(processorInformationBytes), length}; +} + +uint32_t CountSetBits(DWORD input) { + uint32_t c; + for (c = 0; input; c++) { + input &= input - 1; + } + return c; +} + +static CoreCounter GetNumberOPhysicalAndEngineeringCores() { + auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll); + + CoreCounter cores; + DWORD dwLevel2GroupMask = 0; + DWORD dwLevel3GroupMask = 0; + size_t read = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL; + + while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length + ) { + currentProcessorInfo = + reinterpret_cast(logicalProcessorInformation.Buffer.get() + read); + if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) { + break; + } + + switch (currentProcessorInfo->Relationship) { + case RelationProcessorCore: + cores.PhysicalCores++; + break; + case RelationCache: + if (currentProcessorInfo->Cache.Level == 2) { + dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } else if (currentProcessorInfo->Cache.Level == 3) { + dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } + break; + } + + read += currentProcessorInfo->Size; + } + + cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + return cores; +} + +uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { + // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. + // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. + auto cores = GetNumberOPhysicalAndEngineeringCores(); + // We want to use the number of physical cores, but exclude soc cores + return cores.PhysicalCores - cores.SocDieCores; +} + +} // namespace WINMLP diff --git a/winml/lib/Api/HardwareCoreEnumerator.h b/winml/lib/Api/HardwareCoreEnumerator.h new file mode 100644 index 0000000000000..6861ba7d46bcf --- /dev/null +++ b/winml/lib/Api/HardwareCoreEnumerator.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +namespace WINMLP { +struct HardwareCoreEnumerator { + HardwareCoreEnumerator() = delete; + static uint32_t DefaultIntraOpNumThreads(); +}; +} // namespace WINMLP diff --git a/winml/lib/Api/LearningModelDevice.cpp b/winml/lib/Api/LearningModelDevice.cpp index c9c6f5bc70ee2..9f48ee03886e1 100644 --- a/winml/lib/Api/LearningModelDevice.cpp +++ b/winml/lib/Api/LearningModelDevice.cpp @@ -7,6 +7,7 @@ #include #include #include "D3DDeviceCache.h" +#include "HardwareCoreEnumerator.h" #include "ConverterResourceStore.h" @@ -131,7 +132,7 @@ LearningModelDevice::CacheThreadPool(_winml::IThreading* thread_pool) { uint32_t LearningModelDevice::NumberOfIntraOpThreads() { if (IsCpuDevice()) { - return std::thread::hardware_concurrency(); + return HardwareCoreEnumerator::DefaultIntraOpNumThreads(); } else { // GPU sessions should not rely on intra op threads. // Creating a large thread pool is unnecessary and wasteful, and can cause diff --git a/winml/lib/Api/LearningModelSessionOptions.cpp b/winml/lib/Api/LearningModelSessionOptions.cpp index 2ff9c6d1d56d0..374200fb3b9f8 100644 --- a/winml/lib/Api/LearningModelSessionOptions.cpp +++ b/winml/lib/Api/LearningModelSessionOptions.cpp @@ -3,11 +3,20 @@ #include "lib/Api/pch/pch.h" #include "LearningModelSessionOptions.h" +#include "HardwareCoreEnumerator.h" namespace WINMLP { + +LearningModelSessionOptions::LearningModelSessionOptions() { + intra_op_num_threads_override_ = HardwareCoreEnumerator::DefaultIntraOpNumThreads(); +} + LearningModelSessionOptions::LearningModelSessionOptions(const LearningModelSessionOptions& options) : batch_size_override_(options.batch_size_override_), - close_model_on_session_creation_(options.close_model_on_session_creation_) { + close_model_on_session_creation_(options.close_model_on_session_creation_), + named_dim_overrides_(options.named_dim_overrides_), + intra_op_num_threads_override_(options.intra_op_num_threads_override_), + custom_ops_lib_paths_(options.custom_ops_lib_paths_) { } uint32_t LearningModelSessionOptions::BatchSizeOverride() { diff --git a/winml/lib/Api/LearningModelSessionOptions.h b/winml/lib/Api/LearningModelSessionOptions.h index 5fc7e54997403..21d0242735f94 100644 --- a/winml/lib/Api/LearningModelSessionOptions.h +++ b/winml/lib/Api/LearningModelSessionOptions.h @@ -11,7 +11,7 @@ struct LearningModelSessionOptions : LearningModelSessionOptionsT< LearningModelSessionOptions, ILearningModelSessionOptionsNative, ILearningModelSessionOptionsNative1> { - LearningModelSessionOptions() = default; + LearningModelSessionOptions(); LearningModelSessionOptions(const LearningModelSessionOptions& options); @@ -72,7 +72,7 @@ struct LearningModelSessionOptions : LearningModelSessionOptionsT< // The intra operator num threads property is used to control the number of threads used in the threadpool for intra operator calculations. // The default value here is the maximum number of logical cores to ensure that the default behavior of WinML always runs the fastest. 
// WARNING: Setting a number higher than the maximum number of logical cores may result in an inefficient threadpool - uint32_t intra_op_num_threads_override_ = std::thread::hardware_concurrency(); + uint32_t intra_op_num_threads_override_; bool allow_thread_spinning_ = true; diff --git a/winml/test/api/LearningModelSessionAPITest.cpp b/winml/test/api/LearningModelSessionAPITest.cpp index 4ec79b8a0f4c6..d6e70e35e3a6d 100644 --- a/winml/test/api/LearningModelSessionAPITest.cpp +++ b/winml/test/api/LearningModelSessionAPITest.cpp @@ -2195,12 +2195,6 @@ static void SetIntraOpNumThreads() { auto binding = LearningModelBinding(session); binding.Bind(L"input", tensor_input); WINML_EXPECT_NO_THROW(session.Evaluate(binding, L"")); - - // Check to verify that the default number of threads in LearningModelSession is equal to the number of logical cores. - session = LearningModelSession(model, device); - nativeSession = session.as(); - WINML_EXPECT_NO_THROW(nativeSession->GetIntraOpNumThreads(&numIntraOpThreads)); - WINML_EXPECT_EQUAL(std::thread::hardware_concurrency(), numIntraOpThreads); } static void SetIntraOpThreadSpinning() {
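Note on the new WinML default: `HardwareCoreEnumerator::DefaultIntraOpNumThreads()` above derives the intra-op thread count as "physical cores minus SoC-die cores", where SoC-die cores are identified as cores whose group-mask bit appears in an L2 cache mask but not in any L3 cache mask. The short C++ sketch below reproduces only that mask arithmetic so the behavior can be sanity-checked outside a Windows build; the core count, mask values, and variable names are hypothetical stand-ins, not taken from the patch, and the real implementation obtains them from GetLogicalProcessorInformationEx as shown in HardwareCoreEnumerator.cpp.

```cpp
// Minimal sketch of the core-exclusion arithmetic behind DefaultIntraOpNumThreads().
// The topology below is hypothetical; on Windows the masks come from
// GetLogicalProcessorInformationEx (RelationProcessorCore / RelationCache).
#include <cstdint>
#include <iostream>

// Kernighan's bit-count, the same algorithm as CountSetBits() in the patch.
static uint32_t CountSetBits(uint64_t input) {
  uint32_t count = 0;
  for (; input; ++count) {
    input &= input - 1;  // clear the lowest set bit
  }
  return count;
}

int main() {
  // Hypothetical topology: 10 physical cores (bits 0-9) each have an L2 slice,
  // but only cores 0-7 are covered by an L3 slice; cores 8-9 stand in for SoC-die cores.
  const uint32_t physicalCores = 10;
  const uint64_t level2GroupMask = 0x3FF;  // cores 0-9 present in L2 masks
  const uint64_t level3GroupMask = 0x0FF;  // cores 0-7 present in L3 masks

  // Cores with L2 but no L3 are treated as SoC-die cores and excluded.
  const uint32_t socDieCores = CountSetBits(level2GroupMask & ~level3GroupMask);
  const uint32_t defaultIntraOpThreads = physicalCores - socDieCores;

  std::cout << "SoC-die cores: " << socDieCores << "\n"                        // prints 2
            << "default intra-op threads: " << defaultIntraOpThreads << "\n";  // prints 8
  return 0;
}
```

With this change, a hybrid CPU that reports, say, 20 logical processors but only 8 physical cores outside the SoC die will default to 8 intra-op threads rather than std::thread::hardware_concurrency(), which is why the test asserting that the session default equals the logical-core count was removed above.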