From 348a963238a5ebd242259ffb9e158034cdd052af Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 30 Oct 2023 13:48:34 -0700 Subject: [PATCH 01/21] [DML EP] Handle non-raw data in dynamic graph compilation (#18160) --- .../src/DmlGraphFusionHelper.cpp | 59 +++++++++++-------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index cd74e7fa92940..4f7ec188140b5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -103,6 +103,36 @@ namespace DmlGraphFusionHelper ORT_THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); } + std::tuple, std::vector, std::byte*, size_t> UnpackInitializer( + const onnxruntime::Graph& graph, + const ONNX_NAMESPACE::TensorProto* initializer) + { + std::unique_ptr unpackedTensor; + std::vector unpackedExternalTensor; + std::byte* tensorPtr = nullptr; + size_t tensorByteSize = 0; + + // The tensor may be stored as raw data or in typed fields. + if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); + tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); + tensorByteSize = unpackedExternalTensor.size(); + } + else if (initializer->has_raw_data()) + { + tensorPtr = (std::byte*)(initializer->raw_data().c_str()); + tensorByteSize = initializer->raw_data().size(); + } + else + { + std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); + tensorPtr = unpackedTensor.get(); + } + + return std::make_tuple(std::move(unpackedTensor), std::move(unpackedExternalTensor), tensorPtr, tensorByteSize); + } + void ProcessInputData( const ExecutionProviderImpl* providerImpl, const std::vector& isInputsUploadedByDmlEP, @@ -161,32 +191,11 @@ namespace DmlGraphFusionHelper auto iter = initializerNameToInitializerMap.find(subGraphInputArgNames[i]); if (iter != initializerNameToInitializerMap.end()) { - std::byte* tensorPtr = nullptr; - size_t tensorByteSize = 0; - std::vector unpackedExternalTensor; - - std::unique_ptr unpackedTensor; - - //auto& initializer = iter->second; auto* initializer = iter->second.first; + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, initializer); - // The tensor may be stored as raw data or in typed fields. - if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) - { - THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); - tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); - tensorByteSize = unpackedExternalTensor.size(); - } - else if (initializer->has_raw_data()) + if (initializer->data_location() != onnx::TensorProto_DataLocation_EXTERNAL && !initializer->has_raw_data()) { - tensorPtr = (std::byte*)(initializer->raw_data().c_str()); - tensorByteSize = initializer->raw_data().size(); - } - else - { - std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); - tensorPtr = unpackedTensor.get(); - // Free the initializer if this is the last usage of it. if (initializerToLastInputIndexMap[initializer] == i) { @@ -592,9 +601,11 @@ namespace DmlGraphFusionHelper for (auto& kvp : isInitializerTransferable) { + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, kvp.second.first); + ONNX_NAMESPACE::TensorProto tensorProto; tensorProto.set_data_type(kvp.second.first->data_type()); - tensorProto.set_raw_data(kvp.second.first->raw_data()); + tensorProto.set_raw_data(tensorPtr, tensorByteSize); tensorProto.set_name(kvp.second.first->name()); for (int i = 0; i < kvp.second.first->dims_size(); ++i) From 90d1f537cb32f400f5abd16248c7abe1a2738484 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Mon, 30 Oct 2023 14:12:17 -0700 Subject: [PATCH 02/21] optimize SLN with large dimension (#18138) ### Description Optimize SkipLayerNorm for large dimension (>=2048) by handling 8 elements in one thread. It avoid the re-writing and re-loading sum of input, skip and bias to main memory. It reduces the latency of dimension 4096 with small batch size from ~18us to ~3.8us on A100. ### Motivation and Context --- .../cuda/bert/skip_layer_norm_impl.cu | 82 ++++++++++--------- 1 file changed, 45 insertions(+), 37 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index e4b09b00f030c..973ef8d304e2e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,11 +51,11 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. -constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048}; +constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192}; constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; -constexpr int kMaxBlockSize = 256; +constexpr int kMaxBlockSize = 1024; int NextSize(int x) { for (size_t i = 0; i < kNumOfSizes; ++i) { @@ -63,14 +63,13 @@ int NextSize(int x) { return kSizes[i]; } } - return kMaxSize; + return kMaxSize + 1; } -template -bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias, - const T* gamma, const T* beta, const int ld, const int next_size) { - constexpr int alignment = std::alignment_of>::value; - return ld % NumUnroll == 0 && +bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias, + const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) { + int alignment = element_size * num_unroll; + return ld % num_unroll == 0 && reinterpret_cast(output) % alignment == 0 && reinterpret_cast(sum_output) % alignment == 0 && reinterpret_cast(input) % alignment == 0 && @@ -78,8 +77,8 @@ bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, cons reinterpret_cast(bias) % alignment == 0 && reinterpret_cast(gamma) % alignment == 0 && reinterpret_cast(beta) % alignment == 0 && - next_size / NumUnroll >= kMinBlockSize && - next_size / NumUnroll <= kMaxBlockSize; + next_size / num_unroll >= kMinBlockSize && + next_size / num_unroll <= kMaxBlockSize; } } // namespace @@ -187,8 +186,14 @@ void LaunchSkipLayerNormKernel( int ld, int row_count, int skip_size) { const int next_size = NextSize(ld); const int grid_size = row_count; - bool flag_vec2 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); - bool flag_vec4 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); + bool can_unroll_vec4 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 4, sizeof(T)); + bool can_unroll_vec8 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 8, sizeof(T)); #define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ SkipLayerNormKernelSmall<<>>( \ @@ -198,39 +203,42 @@ void LaunchSkipLayerNormKernel( SkipLayerNormKernel<<>>( \ output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) -#define CASE_NEXT_SIZE(next_size_value) \ - case next_size_value: { \ - static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize); \ - if (flag_vec4) { \ - constexpr int block_size = next_size_value / 4; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ - } else if (flag_vec2) { \ - constexpr int block_size = next_size_value / 2; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \ - } else { \ - if (next_size_value <= kMaxBlockSize) { \ - constexpr int block_size = next_size_value; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ - } else { \ - constexpr int block_size = 256; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ - } \ - } \ +#define CASE_NEXT_SIZE(next_size_value) \ + case next_size_value: { \ + static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ + if constexpr (next_size_value >= 8 * 256) { \ + if (can_unroll_vec8) { \ + constexpr int block_size = next_size_value / 8; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } else { \ + if (can_unroll_vec4) { \ + constexpr int block_size = next_size_value / 4; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ + } else { \ + if (next_size_value <= kMaxBlockSize) { \ + constexpr int block_size = next_size_value; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } \ + } \ } break switch (next_size) { - case kSizes[0]: { - constexpr int block_size = kSizes[0]; - // TODO: Add back the small TensorRT kernel for 32. No need to use vertorized kernel for such small size. - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); - break; - } + CASE_NEXT_SIZE(kSizes[0]); CASE_NEXT_SIZE(kSizes[1]); CASE_NEXT_SIZE(kSizes[2]); CASE_NEXT_SIZE(kSizes[3]); CASE_NEXT_SIZE(kSizes[4]); CASE_NEXT_SIZE(kSizes[5]); - // kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize. + CASE_NEXT_SIZE(kSizes[6]); + CASE_NEXT_SIZE(kSizes[7]); default: { constexpr int block_size = 256; LAUNCH_SKIP_LAYER_NORM_KERNEL(); From 785e2b1eaec2b2e69267cc9320f7749315c0e686 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 31 Oct 2023 07:05:35 +0800 Subject: [PATCH 03/21] [js/webgpu] Optimize softmax by vector (#18153) ### Description This PR enables `softmax` outputs max supported components instead of scalar for each thread. Softmax with input[0]: [12,4096,4096] becomes 47.86 ms from 55.11 ms --- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 44 +++++++++++++++------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index d4dbad79e613e..ec651ce34e8c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -10,7 +10,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { @@ -37,23 +37,39 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut const cols = shape[axis]; const rows = outputSize / cols; + const components = getMaxComponents(cols); + const packedCols = cols / components; + const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`; + + const maxVector = (name: string, components: number) => { + if (components === 4) { + return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`; + } else if (components === 2) { + return `max(${name}.x, ${name}.y)`; + } else if (components === 3) { + return `max(max(${name}.x, ${name}.y), ${name}.z)`; + } + + return name; + }; // 6.2.4 in wgsl spec - const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;'; + const threadMaxDecl = + dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`; const getShaderSource = (_shaderHelper: ShaderHelper) => ` - var rowMaxShared : ${dataType}; - var rowSumShared : ${dataType}; - var threadShared : array<${dataType}, ${WG}>; + var rowMaxShared : ${valueType}; + var rowSumShared : ${valueType}; + var threadShared : array<${valueType}, ${WG}>; - @group(0) @binding(0) var x : array<${dataType}>; - @group(0) @binding(1) var result : array<${dataType}>; + @group(0) @binding(0) var x : array<${valueType}>; + @group(0) @binding(1) var result : array<${valueType}>; - fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} { + fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} { let index = row * row_stride + col; return x[index]; } - fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) { + fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) { let index = row * row_stride + col; result[index] = value; } @@ -64,8 +80,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut let lindex = i32(local_id.x); const wg = ${WG}; let row = gindex / wg; - let cols = ${cols}; - let row_stride : i32 = ${cols}; + let cols = ${packedCols}; + let row_stride : i32 = ${packedCols}; // find the rows max ${threadMaxDecl} @@ -87,12 +103,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowMaxShared = threadShared[0]; + rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)}); } workgroupBarrier(); // find the rows sum - var threadSum: ${dataType} = 0.0; + var threadSum = ${valueType}(0.0); for (var col = lindex; col < cols; col += wg) { let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); threadSum += subExp; @@ -107,7 +123,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); } if (lindex == 0) { - rowSumShared = threadShared[0]; + rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)}); } workgroupBarrier(); From efef6407bcdd2464f7f38b09579fc6aeb331655e Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Tue, 31 Oct 2023 08:41:01 +0800 Subject: [PATCH 04/21] [ROCm] update rocm package exclude libs (#18130) update rocm package exclude libs. - change librocblas.so.0 to librocblas.so.3 which is used on ROCm5.6 and ROCm5.7 - add librocfft.so.0, libhipfft.so.0, libhiprtc.so.5 and sort the list. --- setup.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index f6308c56d0590..9eca9845c9e8b 100644 --- a/setup.py +++ b/setup.py @@ -203,19 +203,22 @@ def run(self): "libcurand.so.10", ] rocm_dependencies = [ - "librccl.so.1", - "libnuma.so.1", "libamd_comgr.so.2", + "libamdhip64.so.5", "libdrm.so.2", - "librocblas.so.0", "libdrm_amdgpu.so.1", - "libamdhip64.so.5", - "libroctracer64.so.4", - "libMIOpen.so.1", - "libtinfo.so.6", "libelf.so.1", - "librocm_smi64.so.5", + "libhipfft.so.0", + "libhiprtc.so.5", "libhsa-runtime64.so.1", + "libMIOpen.so.1", + "libnuma.so.1", + "librccl.so.1", + "librocblas.so.3", + "librocfft.so.0", + "librocm_smi64.so.5", + "libroctracer64.so.4", + "libtinfo.so.6", ] tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"] From 8ed9bd6eca143c791c56c055b7a9b0755040a5f3 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 30 Oct 2023 21:21:51 -0700 Subject: [PATCH 05/21] Add one more MHA mask pattern (#18164) Add an MHA mask pattern for the scenario where the mask has already been broadcasted via an Expand node. --- .../transformers/fusion_rotary_attention.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 3c5029ac5752f..44d15b619ec7a 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -427,6 +427,16 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 2, 1, 0, 0, 0], ) + attn_mask_nodes_5 = self.model.match_parent_path( + add_qk, + ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 0, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_6 = self.model.match_parent_path( + add_qk, + ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 2, 1, 0, 0, 0], + ) if attn_mask_nodes_1 is not None: _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 attn_mask = slice_mask_1.output[0] @@ -439,6 +449,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif attn_mask_nodes_4 is not None: # Reshape from (B,1,S,T) to (B,N,S,T) add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0]) + elif attn_mask_nodes_5 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_5[0].output[0] + elif attn_mask_nodes_6 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_6[0].output[0] else: logger.debug("fuse_rotary_attention: failed to match attention mask nodes") return From 8a574b874cf5093b9a1e5383046f6e71d9ada4cf Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Mon, 30 Oct 2023 21:28:02 -0700 Subject: [PATCH 06/21] Update setup_env_cuda.bat (#18176) ### Description ### Motivation and Context --- tools/ci_build/github/windows/setup_env_cuda.bat | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tools/ci_build/github/windows/setup_env_cuda.bat b/tools/ci_build/github/windows/setup_env_cuda.bat index 96569cbe0f648..2233f7611ab6a 100644 --- a/tools/ci_build/github/windows/setup_env_cuda.bat +++ b/tools/ci_build/github/windows/setup_env_cuda.bat @@ -1,15 +1,17 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { - set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% -} else { +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( +set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% +) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% -} +) + @REM The default version is still cuda v11.8, because set cuda v12.2 after it -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ { +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 -} else { +) else ( set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 -} +) + set GRADLE_OPTS=-Dorg.gradle.daemon=false From 08dce542665598cf659447555f4e0eb6d96bf86a Mon Sep 17 00:00:00 2001 From: cloudhan Date: Tue, 31 Oct 2023 13:10:21 +0800 Subject: [PATCH 07/21] Improve tunable verbose log (#17328) --- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 2 +- .../rocm/bert/gemm_fast_gelu_ck.cuh | 8 ++-- .../rocm/diffusion/group_norm_ck.cuh | 2 +- onnxruntime/core/framework/tunable.h | 39 ++++++++++--------- .../core/providers/rocm/math/softmax_ck.cuh | 2 +- .../core/providers/rocm/tunable/gemm_ck.cuh | 6 +-- .../providers/rocm/tunable/gemm_hipblaslt.h | 2 +- .../providers/rocm/tunable/gemm_rocblas.h | 9 ++--- 8 files changed, 35 insertions(+), 35 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 246b66078537a..78983ac95e672 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -838,7 +838,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { Nop{}); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); if constexpr (USE_MASK) { ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); diff --git a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh index cbf24ee2f5487..ea9040aa7875f 100644 --- a/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh @@ -58,7 +58,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias == nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr"); auto nop = Nop{}; auto addfastgelu = AddFastGelu{}; @@ -67,7 +67,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() { params->lda, params->ldb, std::array{0}, params->ldc, nop, nop, addfastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -95,7 +95,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero || params->bias != nullptr, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr"); auto nop = Nop{}; auto fastgelu = FastGelu{}; @@ -108,7 +108,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() { params->ldc, nop, nop, fastgelu); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index e87813fb19956..0146e81c6cf8c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -79,7 +79,7 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { nullptr, activation); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 96b4cc53a022c..6d2dd641f6bc6 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -232,14 +232,15 @@ class TunableOp { return timer.Duration() / num_iter; } - static bool IsSupported(Op& op, const ParamsT* param) { - Status status = op.IsSupported(param); + // Filter all Status, only OK and TUNABLE_OP_UNSUPPORTED is left, other error status will be thrown, and to be + // processed by onnxruntime. We return Status to avoid the construction of op and params signature string. + static Status IsSupported(Op& op, const ParamsT* params) { + Status status = op.IsSupported(params); if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) { - LOGS_DEFAULT(VERBOSE) << "unsupported reason: " << status.ErrorMessage(); - return false; + return status; } ORT_THROW_IF_ERROR(status); - return true; + return status; } protected: @@ -250,9 +251,9 @@ class TunableOp { int FindFastestImpl(const ParamsT* params, const std::vector>& candidates) { ITuningContext* ctx = params->TuningContext(); auto op_sig = Signature(); - auto param_sig = params->Signature(); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ')'; - auto min_time = std::numeric_limits::infinity(); + auto params_sig = params->Signature(); + LOGS_DEFAULT(VERBOSE) << "finding fastest for " << op_sig << '(' << params_sig << ')'; + auto min_duration_ms = std::numeric_limits::infinity(); int id = -1; constexpr const int max_tuning_iter = 100; @@ -260,30 +261,32 @@ class TunableOp { for (size_t i = 0; i < candidates.size(); i++) { auto& candidate = const_cast&>(candidates[i]); - if (!IsSupported(candidate, params)) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig << '(' << param_sig << ") id=" << i; + auto status = IsSupported(candidate, params); + if (!status.IsOK()) { + LOGS_DEFAULT(VERBOSE) << "├──unsupported id=" << i << ", " << op_sig << '(' << params_sig << ")"; + LOGS_DEFAULT(VERBOSE) << "│ reason: " << status.ErrorMessage(); continue; } WarmUp(candidate, params); auto approx_duration = Profile(candidate, params, approx_num_iter); - if (approx_duration > 2 * min_time) { - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl skip slow instance " << op_sig << '(' << param_sig << ") id=" << i; + if (approx_duration > 2 * min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──skip slow instance id=" << i; continue; } int tuning_iter = std::max(1, int(std::min(double(max_tuning_iter), ctx->GetMaxTuningDurationMs() / approx_duration))); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl run instance " << op_sig << '(' << param_sig << ") id=" << i << " " << tuning_iter << " times."; - - auto time = Profile(candidate, params, tuning_iter); - if (time < min_time) { - min_time = time; + auto duration_ms = Profile(candidate, params, tuning_iter); + if (duration_ms < min_duration_ms) { + LOGS_DEFAULT(VERBOSE) << "├──found better instance, new best id=" << i << ", old id=" << id << ". " + << duration_ms << "ms, " << tuning_iter << " iters."; + min_duration_ms = duration_ms; id = static_cast(i); } } ORT_ENFORCE(id >= 0, "Could not find viable op"); - LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ") found fastest with id=" << id; + LOGS_DEFAULT(VERBOSE) << "└──found fastest with id=" << id << " for " << op_sig << '(' << params_sig << ")"; std::this_thread::sleep_for(std::chrono::milliseconds(50)); return id; } diff --git a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh index 5830c9dd0bf27..f87b436d04a17 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_ck.cuh +++ b/onnxruntime/core/providers/rocm/math/softmax_ck.cuh @@ -58,7 +58,7 @@ auto GetCKSoftmaxTypeStringAndOps() { auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, alpha, beta, params->input, params->output, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh index 86d023886cfaf..2518f45e0995e 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh +++ b/onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh @@ -61,7 +61,7 @@ auto GetCKGemmTypeStringAndOps() { params->lda, params->ldb, params->ldc, nop, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; @@ -164,7 +164,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() { auto zero = ToHipType::FromFloat(0.0f); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->alpha != one || params->beta != zero, - impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature()); + impl->GetTypeString(), " only supports alpha == 1 and beta == 0"); auto nop = Nop{}; auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c, @@ -174,7 +174,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() { params->batch, nop, nop, nop); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support ", params->Signature()); + impl->GetTypeString(), " does not support the params"); invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); return Status::OK(); }; diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index d5f9de26ada22..b9c0cdcc1c341 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -221,7 +221,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != HIPBLAS_STATUS_SUCCESS, - "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported (", params->Signature(), ")"); + "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported"); IAllocatorUniquePtr workspace_buffer; if (workspace_size > 0) { diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h index 8e894e63c5de1..a391d1af8868c 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h @@ -168,8 +168,7 @@ auto GetRocBlasGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; @@ -238,8 +237,7 @@ auto GetRocBlasBatchedGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; @@ -308,8 +306,7 @@ auto GetRocBlasStridedBatchedGemmTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( status != rocblas_status_success, - "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status), - " (", params->Signature(), ")"); + "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status)); return Status::OK(); }; From 1c25fe55800b05ffd2d89d2c6f69b39820bc5c4f Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 31 Oct 2023 13:53:11 +0800 Subject: [PATCH 08/21] Fix PoliCheck (#18180) Fix PoliCheck by changing some words, which was from Triton flash attention's original code. --- .../python/training/ort_triton/kernel/_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py index 40398b33d8f04..03bb0f4373d8d 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py @@ -393,7 +393,7 @@ def _bwd_kernel_one_col_block( dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) # There seems to be some problem with Triton pipelining that makes results wrong for # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop - # may have zero step, and pipelining with the bias matrix could screw it up. + # may have zero step, and pipelining with the bias matrix could cause the problem. # So we just exit early. if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) From 29e40987e3bc5b8f37c14c8ea6b6eccb620dee42 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Tue, 31 Oct 2023 10:22:40 -0700 Subject: [PATCH 09/21] Update batch file to set PATH for Cuda with TRT (#18182) ### Description Update batch file to set PATH for Cuda with TRT ### Motivation and Context --- .../ci_build/github/windows/setup_env_gpu.bat | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat index 4328c6eba1fe1..49b536e6ab81e 100644 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ b/tools/ci_build/github/windows/setup_env_gpu.bat @@ -1,11 +1,21 @@ REM Copyright (c) Microsoft Corporation. All rights reserved. REM Licensed under the MIT License. -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { +if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( set PATH=%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64;%PATH% -} else { +) else ( set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% -} -set PATH=C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Current\Bin;%PATH% +) +set PATH=C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib;%PATH% + +@REM The default version is still cuda v11.8, because set cuda v12.2 after it +set PATH=%PATH%;C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ ( + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 +) else ( + set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\extras\CUPTI\lib64 +) + + set GRADLE_OPTS=-Dorg.gradle.daemon=false set CUDA_MODULE_LOADING=LAZY From 95f053c652ac1e7afb3d795ea5df2b165b2429c2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Oct 2023 10:27:20 -0700 Subject: [PATCH 10/21] [CUDA] Update GroupNorm and Add SkipGroupNorm (#18091) * Add a new operator SkipGroupNorm to support skip and bias inputs. * Update GroupNorm kernel to support number of channels used in SD XLrefiner. * Add epsilon in kernel * Add parity and performance test script * Remove many limitations including max batch size, max number of groups, c % cPerBlock ==0 etc. ### Motivation and Context Update GroupNorm to support SD XL Refiner and beyond. --- docs/ContribOperators.md | 70 +- docs/OperatorKernels.md | 1 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + .../contrib_ops/cuda/diffusion/group_norm.cc | 124 +++- .../contrib_ops/cuda/diffusion/group_norm.h | 7 +- .../cuda/diffusion/group_norm_impl.cu | 604 ++++++++++++------ .../cuda/diffusion/group_norm_impl.h | 38 +- .../contrib_ops/rocm/diffusion/group_norm.cc | 6 + .../core/graph/contrib_ops/diffusion_defs.cc | 81 ++- onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../python/tools/symbolic_shape_infer.py | 6 + .../tools/transformers/io_binding_helper.py | 4 +- .../contrib_ops/skip_group_norm_op_test.cc | 286 +++++++++ .../python/transformers/test_group_norm.py | 541 ++++++++++++++++ 14 files changed, 1531 insertions(+), 241 deletions(-) create mode 100644 onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc create mode 100644 onnxruntime/test/python/transformers/test_group_norm.py diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 890403556cc47..ed1049b0bd73a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -95,6 +95,7 @@ Do not modify directly.* * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling + * com.microsoft.SkipGroupNorm * com.microsoft.SkipLayerNormalization * com.microsoft.SkipSimplifiedLayerNormalization * com.microsoft.Snpe @@ -2342,7 +2343,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
-
Activation after group normalization: 0 for None, 1 for Swish
+
Activation after group normalization: 0 for None, 1 for SiLU
channels_last : int
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
epsilon : float
@@ -2582,6 +2583,7 @@ This version of the operator has been available since version 1 of the 'com.micr Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + #### Version This version of the operator has been available since version 1 of the 'com.microsoft' operator set. @@ -5083,6 +5085,72 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.SkipGroupNorm** + + This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + + This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. + The num_channels must be divisible by num_groups. + The mean and standard-deviation of s are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for SiLU
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs (4 - 5) + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
skip : T
+
4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)
+
bias (optional) : T
+
1D bias tensor. Dimensions are (C), where C is number of channels
+
+ +#### Outputs (1 - 2) + +
+
Y : T
+
The output tensor of the same shape as X
+
S (optional) : T
+
The element-wise sum of input x, skip and bias tensors. It has the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X, skip, bias and output Y, S types to float tensors.
+
M : tensor(float16), tensor(float)
+
Constrain gamma and beta to float tensors.
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bfb7716dc5cea..dcdf73cbdbf08 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -861,6 +861,7 @@ Do not modify directly.* |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 2618fe4a238bd..d51915b85095f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -97,6 +97,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -269,6 +270,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 301b2e76b1b2d..87e88ac31c998 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" @@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX( GroupNorm, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); +ONNX_OPERATOR_KERNEL_EX( + SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + using namespace ONNX_NAMESPACE; namespace { + template struct DispatchGroupNorm { Status operator()(cudaStream_t stream, Tensor* output, + Tensor* add_out, const Tensor* input, + const Tensor* skip, + const Tensor* bias, const Tensor* gamma, const Tensor* beta, void* workspace, @@ -32,12 +39,17 @@ struct DispatchGroupNorm { int height, int width, int num_groups, - bool use_swish_activation) { + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( stream, reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), gamma->Data(), beta->Data(), workspace, @@ -47,13 +59,21 @@ struct DispatchGroupNorm { height, width, num_groups, - use_swish_activation); + use_swish_activation, + broadcast_skip, + channels_per_block); } }; } // namespace GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + has_skip_ = false; + const std::string& op_name = op_info.GetKernelDef().OpName(); + if (op_name == "SkipGroupNorm") { + has_skip_ = true; + } + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); + + channels_per_block_ = 0; +} + +Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + + // Compute and cache cPerBlock using number of channels from gamma tensor shape. + if (input_idx == 1) { + auto gamma_shape = tensor.Shape(); + if (gamma_shape.NumDimensions() == 1) { + channels_per_block_ = GetChannelsPerBlock(static_cast(gamma_shape[0]), num_groups_); + } + } + + return Status::OK(); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "only the channels_last layout is supported"); } + if (!gamma->IsDataType() || !beta->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm only supports gamma and beta in float type"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input is expected to have 4 dimensions, got ", input_dims.size()); } + // Only support NHWC format right now. + int batch_size = static_cast(input_dims[0]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + int num_channels = static_cast(input_dims[3]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[3]) { + if (gamma_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[3]) { + if (beta_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); - } - if (context->GetUseDeterministicCompute()) { static std::once_flag log_warning; std::call_once(log_warning, []() { @@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { }); } - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + const Tensor* skip = nullptr; + const Tensor* bias = nullptr; + Tensor* add_out = nullptr; + + bool broadcast_skip = false; + if (has_skip_) { + skip = context->Input(3); + bias = context->Input(4); + add_out = context->Output(1, input->Shape()); + + if (bias != nullptr) { // Bias is optional + // If provided, bias has shape (C). + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != num_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in bias and input does not match"); + } + } + + // Check whether skip can be broadcasted to input shape. + if (skip->Shape() != input->Shape()) { + const auto& dims = skip->Shape().GetDims(); + // The shape of ship can be (N, C) or (N, 1, 1, C) for broadcast. + const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels); + const bool b4 = (dims.size() == 4 && dims[0] == batch_size && + dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels); + broadcast_skip = b2 || b4; + if (!broadcast_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)"); + } + } + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), + context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), epsilon_, batch_size, num_channels, height, width, num_groups_, - use_swish_activation_); + use_swish_activation_, + broadcast_skip, + channels_per_block_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 52c006e6bdb96..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel { GroupNorm(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: - bool use_swish_activation_; + bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization? float epsilon_; int num_groups_; bool channels_last_; + bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm + int channels_per_block_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 01ba078b4be77..48b161552ce0c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -16,18 +16,45 @@ */ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { -static inline int32_t divUp(int32_t m, int32_t n) { +namespace { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} +} // namespace + +static inline int32_t DivUp(int32_t m, int32_t n) { return (m + n - 1) / n; } @@ -41,14 +68,14 @@ struct GroupSums { // The sum. float sum; // The sum of squares. - float sumSq; + float sum_sq; }; struct GroupSumsOp { inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { GroupSums dst; dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); dst.flag = a.flag + b.flag; return dst; } @@ -56,54 +83,85 @@ struct GroupSumsOp { template struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. + // The output buffer. Shape is (n, h, w, c). T* dst; - // The input buffer. Layout NHWC. + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + // The gamma scaling factor. float const* gamma; + // The beta term to add in GN. float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; // The number of instances in the batch. int32_t n; + // The height and width of each activation map. int32_t h; int32_t w; - // The number of channels. + + // Number of channels. int32_t c; - // The number of groups. + + // Number of groups. int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; + + // Do we apply the SiLU activation function? + bool use_silu; // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of - // activations per block. + // Number of activations per instance (h * w) int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; // The precomputed stride between instances. int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; // The precomputed number of groups per block. - int32_t groupsPerBlock; + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; }; template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t hw_begin = blockIdx.y * params.hw_per_block; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); // The sums. float sum = 0.F; - float sumSq = 0.F; + float sum_sq = 0.F; // Iterate over the activations to compute the sums. - if (ci < params.c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - UpdateSum(params.src, offset, sum, sumSq); + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); } } - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - // Do the segmented scan. + // Do the segmented scan. InclusiveScan is not deterministic. GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == params.cPerGroup - 2) { //2 channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); - // The global group index. - int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + if (threadIdx.x >= params.groups_per_block) { return; } - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } } template -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the values are as we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); + // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); + // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCSumKernel<<>>(params); break; - case 480: - groupNormNHWCSumKernel<<>>(params); + case 192: + GroupNormNHWCSumKernel<<>>(params); break; - case 256: - groupNormNHWCSumKernel<<>>(params); + case 160: + GroupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<<>>(params); + GroupNormNHWCSumKernel<<>>(params); + break; + case 64: + GroupNormNHWCSumKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); template <> -__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo float2 f2 = __half22float2(h2); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo } template <> -__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - if (ci >= params.c) { +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { return; } // The instance in the batch. int32_t ni = blockIdx.z; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; + float sum = 0.F, sum_sq = 0.F; if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. - float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + float inv_std_dev = rsqrtf(var + params.epsilon); - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); } } template -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCScaleKernel<<>>(params); break; - case 480: - groupNormNHWCScaleKernel<<>>(params); + case 192: + GroupNormNHWCScaleKernel<<>>(params); break; - case 256: - groupNormNHWCScaleKernel<<>>(params); + case 160: + GroupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + GroupNormNHWCScaleKernel<<>>(params); + break; + case 64: + GroupNormNHWCScaleKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; for (int32_t i = 1; i <= std::sqrt(n); i++) { if (n % i == 0) { int32_t divisor1 = n / i; int32_t divisor2 = i; - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; } } } - return maxDivisor; + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; } template Status LaunchGroupNormKernel( cudaStream_t stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -397,79 +588,94 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCParams params; - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } - GroupNormNHWCParams params; - int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; - switch (num_channels) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; + // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases + if (channels_per_block % channels_per_group != 0 || + channels_per_block > kMaxSize || + (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - params.withSwish = use_swish_activation; + params.use_silu = use_silu; params.dst = output; + params.add_out = add_out; params.src = input; + params.skip = skip; + params.bias = bias; params.gamma = gamma; params.beta = beta; - params.redBuffer = reinterpret_cast(workspace); + params.group_sum_buffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; params.c = num_channels; params.groups = num_groups; params.hw = params.h * params.w; - const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); - params.hwPerBlock = divUp(params.hw, blocksPerHW); - params.cPerBlock = cPerBlock; - params.cPerGroup = params.c / params.groups; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); + params.hw_per_block = DivUp(params.hw, blocks_per_hw); + + params.channels_per_block = channels_per_block; + params.channels_per_group = channels_per_group; params.hwc = params.hw * params.c; - params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); - params.groupsPerBlock = cPerBlock / params.cPerGroup; + params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); + params.groups_per_block = channels_per_block / params.channels_per_group; + params.epsilon = epsilon; + params.broadcast_skip = broadcast_skip; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("input", input, batch_size, num_channels, height * width); - DUMP_TENSOR("gamma", gamma, 1, num_channels); - DUMP_TENSOR("beta", beta, 1, num_channels); - cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); - DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; + + params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + + GroupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups); + + GroupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index c7e9245050ee6..9532aeecb2f57 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -12,29 +12,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { +constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) { // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; + return (sizeof(float) * 2) * batch_size * num_groups; } +int GetChannelsPerBlock(int num_channels, int num_groups); + template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) + const T* skip, // optional skip tensor. Shape is (n, h, w, c) + const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm + const float* gamma, // gamma (also known as weight or scale). Shape is (c) + const float* beta, // beta (also known as bias). Shape is (c) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization + bool broadcast_skip, // Whether skip need broadcast. When skip has shape (n, c) or (n, 1, 1, c), it need broadcast. + int channels_per_block // Pre-computed channels per block. ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index c665da89af36c..e82e15a304f4c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } +Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c2f5edaa6149b..f81c3b8e0182c 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The number of groups of channels. It should be a divisor of the number of channels C", AttributeProto::INT) .Attr("activation", - "Activation after group normalization: 0 for None, 1 for Swish", + "Activation after group normalization: 0 for None, 1 for SiLU", AttributeProto::INT) .Attr("channels_last", "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", @@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +constexpr const char* SkipGroupNorm_ver1_doc = R"DOC( +This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + +This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. +The num_channels must be divisible by num_groups. +The mean and standard-deviation of s are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + SkipGroupNorm, 1, + OpSchema() + .SetDoc(SkipGroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", + AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for SiLU", + AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 " + " or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels," + " and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(3, + "skip", + "4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)", + "T") + .Input(4, + "bias", + "1D bias tensor. Dimensions are (C), where C is number of channels", + "T", + OpSchema::Optional) + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .Output(1, + "S", + "The element-wise sum of input x, skip and bias tensors. It has the same shape as X", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + })); + constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left tensor multiplies the Gelu activation result of right tensor. diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index aa31f3b5a7c62..b35cfc5d12f36 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul); @@ -205,6 +206,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef1c46b83946a..9b68aef57656e 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -200,6 +200,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFastGelu": self._infer_GemmFastGelu, "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, + "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, "MultiHeadAttention": self._infer_MultiHeadAttention, @@ -2376,6 +2377,11 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_SkipGroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node, 0, 0) + if len(node.output) > 1: + self._propagate_shape_and_type(node, 0, 1) + def _infer_BiasSplitGelu(self, node): # noqa: N802 input_shape = self._get_shape(node, 0) bias_shape = self._get_shape(node, 1) diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index de17f195c99cc..50703b9c17e03 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,6 +1,6 @@ import logging from collections import OrderedDict -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union import numpy import torch @@ -229,7 +229,7 @@ def __del__(self): del self.io_binding del self.ort_session - def allocate_buffers(self, shape_dict: Dict[str, tuple]): + def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]): """Allocate tensors for I/O Binding""" if self.enable_cuda_graph: for name, shape in shape_dict.items(): diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc new file mode 100644 index 0000000000000..fefd5722054de --- /dev/null +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; + +namespace onnxruntime { +namespace test { + +TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { + constexpr int64_t B = 2; + constexpr int64_t C = 16; + constexpr int64_t H = 2; + constexpr int64_t W = 2; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + -0.768555f, 1.575195f, -0.698242f, 1.587891f, 0.371826f, -0.280029f, -1.328125f, 0.127197f, + -0.197144f, 0.982422f, -0.671387f, -1.925781f, 1.800781f, -0.020218f, -0.782227f, 1.291992f, + -0.935059f, 1.782227f, -0.674316f, -1.943359f, -0.218994f, 0.054138f, -1.539062f, -0.546387f, + -2.160156f, 1.195312f, 1.653320f, -0.674316f, 0.224731f, -0.093262f, 1.160156f, -0.389404f, + 1.748047f, 0.766113f, 0.234375f, 0.011177f, -0.055847f, -0.930664f, -0.490234f, -0.655762f, + -0.382568f, -0.554688f, 0.910645f, -0.227295f, 1.687500f, 0.028397f, -0.241699f, -0.480957f, + -0.355713f, -2.095703f, -0.443359f, -0.126221f, -0.815918f, 0.792969f, -0.450439f, -0.952148f, + -1.174805f, 0.242798f, 0.138550f, -0.237061f, -0.994141f, 0.346436f, 0.147705f, 0.125854f, + -0.517090f, 0.253906f, 0.400146f, -0.540039f, -0.788574f, 0.146606f, -0.409668f, 0.281982f, + 1.444336f, 0.044434f, -0.366699f, 2.250000f, -0.453613f, -0.652344f, 1.828125f, -0.244751f, + 0.307129f, -0.051361f, 0.106384f, 0.844727f, 1.648438f, -0.904785f, -0.353760f, 0.510742f, + 0.074829f, -0.311279f, 0.274902f, 1.594727f, 1.367188f, 0.098755f, 0.043304f, -0.207397f, + 0.068298f, -0.601074f, 0.083008f, 0.264893f, -0.659180f, -0.216797f, -0.086548f, -0.683594f, + -0.964844f, -2.591797f, -0.817383f, -0.461914f, -1.840820f, -0.712402f, -0.052094f, -0.583008f, + 1.114258f, 0.190308f, 1.087891f, 0.005146f, 1.041992f, 1.363281f, -0.273682f, -0.465576f, + -0.027618f, 1.345703f, 0.789551f, -0.015991f, 0.401611f, 0.726562f, 0.598633f, 0.133667f}; + + std::vector gamma_data = { + 0.241255f, 0.556660f, -0.835532f, 0.564596f, -1.338308f, -0.278924f, 0.357326f, -1.745484f, + 0.277184f, 0.101415f, -0.018637f, -0.526188f, -0.011698f, -2.349411f, 0.206578f, 0.357679f}; + + std::vector beta_data = { + -1.194839f, 0.209146f, -0.677225f, -0.547338f, 1.275685f, -1.099577f, 0.470916f, 0.293907f, + -1.094209f, 2.350204f, -1.633769f, 0.248753f, -0.180166f, 0.365134f, -0.555731f, 1.843083f}; + + std::vector skip_data_nhwc = { + 0.892578f, -0.471924f, -0.423096f, 1.277344f, 0.257080f, -1.366211f, 1.552734f, 0.441406f, + -0.033142f, -0.059418f, 1.536133f, -0.225464f, 1.472656f, 0.591309f, -0.386230f, -2.197266f, + 0.089600f, -0.256592f, -1.873047f, 0.916992f, 0.392090f, 0.015526f, -0.949219f, 0.566895f, + -0.220459f, 1.262695f, -0.437744f, -2.283203f, -0.264893f, -0.660156f, 2.353516f, 1.992188f, + 0.865723f, -0.854004f, -1.014648f, 0.899414f, -1.041016f, 1.378906f, -0.075073f, -2.541016f, + -0.883789f, -0.428711f, 0.981934f, -0.072754f, 2.214844f, 0.658203f, 0.170166f, -1.727539f, + -0.672363f, -1.373047f, 0.318115f, 0.422363f, 0.260742f, -0.547852f, 0.545898f, -0.155762f, + 0.679688f, 2.861328f, -0.300781f, -0.504883f, 1.548828f, 0.353760f, -0.387695f, -1.595703f, + -0.170166f, -0.002897f, 0.273193f, -0.383545f, -1.082031f, -0.894043f, -1.048828f, -0.044708f, + 0.049286f, 0.220215f, 0.272705f, -0.853027f, -0.489258f, 0.513672f, 0.977051f, 0.310547f, + -0.577148f, -0.479004f, 0.838867f, 0.872559f, -0.510254f, 0.101807f, -0.299805f, -1.179688f, + -1.555664f, 0.668457f, 0.939453f, 0.118103f, -0.376709f, 0.735352f, -0.214233f, -1.987305f, + -0.931152f, 1.268555f, 1.427734f, -0.757812f, -1.324219f, 0.375488f, 1.364258f, -1.708008f, + 0.976562f, -0.037659f, -1.779297f, -0.196655f, 1.636719f, 0.690430f, 0.941895f, -1.882812f, + 0.431641f, 0.203857f, 1.306641f, -0.126343f, 1.408203f, 1.188477f, 0.432861f, -2.296875f, + -0.475342f, 1.517578f, -0.824219f, 1.288086f, -0.028244f, 1.918945f, 0.352295f, 0.693359f}; + + std::vector bias_data = { + -0.537598f, 0.500488f, -0.252441f, -0.460693f, -1.640625f, -1.298828f, 0.331787f, -1.588867f, + 1.000977f, 1.458984f, 0.702637f, 0.147827f, 1.143555f, 0.533691f, -0.072510f, 0.511230f}; + + std::vector norm_data_nhwc = { + -1.213867f, 0.856445f, -0.119141f, 0.386475f, 0.714355f, -0.804688f, + 1.048828f, -0.426270f, -1.091797f, 2.435547f, -1.641602f, 0.989746f, + -0.200928f, 0.267334f, -0.800781f, 1.577148f, -1.357422f, 1.000977f, + 0.613281f, -0.963867f, 1.179688f, -1.169922f, 0.308350f, 0.304199f, + -1.396484f, 2.513672f, -1.644531f, 1.206055f, -0.180664f, 1.896484f, + -0.294678f, 2.046875f, -0.844238f, 0.448486f, -0.294189f, -0.291504f, + 2.480469f, -1.250977f, 0.833008f, 4.593750f, -1.238281f, 2.335938f, + -1.651367f, 0.491943f, -0.204834f, 0.125610f, -0.682129f, 1.333984f, + -1.384766f, -0.708008f, -0.630859f, -0.504883f, 1.924805f, -1.208008f, + 1.013672f, 1.809570f, -1.128906f, 2.546875f, -1.631836f, 0.610840f, + -0.184326f, 0.110046f, -0.700195f, 1.471680f, -1.511719f, 0.492188f, + -0.847168f, -1.373047f, 2.837891f, -0.998047f, 0.521484f, 0.262207f, + -0.810547f, 2.400391f, -1.628906f, 0.049896f, -0.174927f, 1.076172f, + -0.252197f, 1.784180f, -1.418945f, 0.090820f, -1.056641f, 0.002945f, + 0.627441f, -0.989746f, 0.679199f, 1.130859f, -1.371094f, 2.408203f, + -1.645508f, -0.062988f, -0.192017f, -0.655762f, -0.718262f, 1.170898f, + -1.550781f, 0.706055f, -1.492188f, -1.148438f, 2.921875f, -1.136719f, + 1.058594f, 2.781250f, -1.089844f, 2.201172f, -1.597656f, 0.785645f, + -0.181396f, 0.868164f, -0.552246f, 1.097656f, -1.015625f, 0.565430f, + -2.173828f, -0.955078f, -0.336426f, -1.503906f, 0.838867f, 3.136719f, + -1.186523f, 2.580078f, -1.629883f, 0.094604f, -0.186523f, -3.884766f, + -0.542480f, 1.990234f}; + + std::vector add_out_data_nhwc = { + -0.414062f, 1.604492f, -1.374023f, 2.404297f, -1.011719f, -2.945312f, 0.556641f, -1.020508f, + 0.770508f, 2.382812f, 1.567383f, -2.003906f, 4.417969f, 1.105469f, -1.240234f, -0.394531f, + -1.382812f, 2.027344f, -2.800781f, -1.487305f, -1.466797f, -1.229492f, -2.156250f, -1.568359f, + -1.379883f, 3.917969f, 1.917969f, -2.808594f, 1.103516f, -0.219727f, 3.441406f, 2.113281f, + 2.076172f, 0.412598f, -1.033203f, 0.449951f, -2.738281f, -0.851562f, -0.233521f, -4.785156f, + -0.265625f, 0.475586f, 2.595703f, -0.152222f, 5.046875f, 1.220703f, -0.144043f, -1.697266f, + -1.566406f, -2.968750f, -0.377686f, -0.164551f, -2.195312f, -1.053711f, 0.427246f, -2.697266f, + 0.505859f, 4.562500f, 0.540527f, -0.594238f, 1.698242f, 1.233398f, -0.312500f, -0.958496f, + -1.224609f, 0.751465f, 0.420898f, -1.384766f, -3.511719f, -2.046875f, -1.126953f, -1.351562f, + 2.494141f, 1.724609f, 0.608398f, 1.544922f, 0.200684f, 0.395020f, 2.732422f, 0.577148f, + -0.807617f, -0.029785f, 0.692871f, 1.256836f, -0.502441f, -2.101562f, -0.321777f, -2.257812f, + -0.479492f, 1.816406f, 1.916992f, 1.860352f, 2.134766f, 1.367188f, -0.243408f, -1.683594f, + -1.400391f, 1.167969f, 1.257812f, -0.953613f, -3.625000f, -1.140625f, 1.609375f, -3.980469f, + 1.012695f, -1.170898f, -1.894531f, -0.510742f, 0.939453f, 0.511719f, 0.817383f, -1.955078f, + 1.007812f, 0.894531f, 2.142578f, -0.582031f, 0.809570f, 1.252930f, 0.490967f, -4.351562f, + 0.497803f, 4.320312f, 0.667969f, 1.419922f, 1.516602f, 3.179688f, 0.878906f, 1.337891f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array channels_last_values = {-1, 1}; + + for (const int channels_last : channels_last_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 4); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + test.AddInput("skip", dims_nhwc, ToFloat16(skip_data_nhwc)); + test.AddInput("bias", {C}, ToFloat16(bias_data)); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { + constexpr int64_t B = 1; + constexpr int64_t C = 64; + constexpr int64_t H = 1; + constexpr int64_t W = 1; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + 0.588867f, 0.896484f, -0.213623f, 0.803223f, 0.659180f, -0.216187f, 1.197266f, -0.486084f, + -0.718750f, 0.332031f, -0.364746f, -0.831543f, -0.031219f, -1.059570f, 0.161621f, 1.519531f, + 0.169312f, 1.048828f, 1.330078f, 0.450195f, -2.867188f, -1.456055f, 0.708496f, -1.120117f, + -1.208984f, -1.199219f, -1.505859f, -0.549316f, 0.505371f, 0.723145f, -0.359131f, -0.250977f, + -0.879883f, -0.305664f, 0.709473f, 0.815430f, 0.617676f, -0.638672f, 0.066772f, -2.330078f, + -1.316406f, 1.744141f, 1.122070f, -0.633789f, -1.802734f, -0.825684f, 0.622559f, -0.481689f, + -1.364258f, -0.536621f, -0.464111f, 0.247437f, -0.213989f, 0.384521f, 0.556641f, -0.303711f, + -0.160034f, 0.882324f, -0.212036f, -0.796387f, 0.153076f, -1.311523f, 2.212891f, 0.685059f}; + + std::vector gamma_data = { + 0.789682f, 0.869051f, -0.010169f, -0.021685f, 0.506611f, 1.267444f, -0.312695f, 0.877844f, + 0.598637f, 0.598314f, -1.721544f, -0.593328f, 0.986705f, -0.419391f, -0.852584f, -0.572351f, + 0.912797f, -0.586863f, 0.477761f, -0.484418f, -0.193835f, 0.347757f, 0.327637f, -1.100304f, + 1.233108f, -0.272569f, -0.688656f, 0.687245f, 0.398386f, 0.888089f, -0.792587f, -0.769029f, + -0.427778f, 0.100768f, -2.187060f, 1.279301f, 1.109054f, 0.375992f, 1.514775f, 1.271436f, + 0.822896f, -0.476750f, 0.475507f, -1.011297f, 1.177197f, 1.586540f, -1.059944f, -0.145351f, + 0.841555f, -2.014113f, -0.230498f, 0.302128f, -0.180508f, 0.980534f, -0.126871f, 0.203151f, + -0.754841f, 0.420570f, -1.085798f, 1.335042f, -0.674930f, 2.453507f, 2.139259f, 1.087436f}; + + std::vector beta_data = { + -0.064518f, -0.262683f, 0.827528f, -0.960938f, 1.062519f, 2.417941f, 0.212789f, -1.638430f, + 1.875453f, -0.883058f, -0.006704f, 0.424894f, -0.869972f, 0.727008f, 0.879303f, -3.024141f, + -2.610873f, 1.269641f, 0.883006f, 0.804167f, -1.510324f, 2.258091f, -0.006750f, -1.553668f, + -1.659453f, 0.579603f, 0.652358f, 0.007077f, 0.099180f, 0.418658f, -0.273778f, -1.036199f, + -1.128691f, -0.296022f, -0.224056f, 1.476306f, 0.577624f, -0.372049f, -0.581659f, -1.841807f, + -0.361721f, 0.051160f, -0.749332f, -2.634807f, 0.562719f, -0.738667f, 0.024864f, -1.135937f, + -1.368144f, -1.458886f, -0.946683f, 1.953936f, -1.198661f, 0.166648f, 0.447206f, -0.458140f, + -0.553395f, 0.112900f, 0.255989f, -0.184551f, 1.254163f, -0.260479f, -1.232429f, 1.902575f}; + + std::vector skip_data = { + 0.952148f, 1.342773f, -0.172974f, -0.395264f, 1.119141f, 0.330566f, + 0.281494f, 0.472900f, -0.692871f, -0.634766f, 0.013504f, -1.866211f, + -0.428223f, 0.669922f, -0.323486f, 0.713867f, -0.350586f, 0.659180f, + -0.288574f, 0.324219f, -0.300781f, -0.789551f, -0.216431f, -0.221436f, + -0.086670f, 0.366211f, -0.643555f, -0.977051f, 0.001021f, 0.415527f, + -0.271729f, 0.836426f, 0.035370f, -0.806152f, 0.936035f, -0.021332f, + -1.095703f, 0.971680f, 1.648438f, 0.840820f, 0.837402f, 0.607910f, + -1.894531f, 0.666016f, -0.171143f, 1.625977f, -0.620117f, -0.039581f, + 1.702148f, -2.410156f, 1.565430f, -0.756348f, 1.446289f, 0.583496f, + -0.497559f, -0.271729f, -0.956055f, -1.642578f, 0.833496f, -1.136719f, + 1.248047f, -2.515625f, 0.080383f, 0.376221f}; + + std::vector norm_data_nhwc = { + 0.494873f, 1.017578f, 0.841797f, -0.949219f, 1.552734f, 1.333984f, 0.012703f, -2.511719f, + 1.424805f, -0.818359f, -0.128418f, 1.462891f, -0.882812f, 0.709961f, 0.693848f, -4.210938f, + -2.505859f, 0.513184f, 1.300781f, 0.460938f, -1.172852f, 1.851562f, 0.167969f, -0.885254f, + -2.535156f, 0.656738f, 1.683594f, -0.627441f, 0.478271f, 1.782227f, -0.196777f, -1.824219f, + -0.791016f, -0.398682f, -3.197266f, 2.275391f, 0.052704f, -0.286865f, 1.567383f, -3.552734f, + -0.646973f, -0.927734f, -1.032227f, -2.722656f, -1.337891f, 0.432129f, -0.040253f, -1.080078f, + -1.118164f, 3.123047f, -1.153320f, 1.843750f, -1.378906f, 0.941406f, 0.437256f, -0.542969f, + -0.218872f, 0.006115f, -0.265869f, -1.356445f, 0.649902f, -4.882812f, 1.696289f, 2.679688f}; + + std::vector add_out_data_nhwc = { + 1.541016f, 2.238281f, -0.386719f, 0.407959f, 1.778320f, 0.114380f, + 1.478516f, -0.013184f, -1.412109f, -0.302734f, -0.351318f, -2.697266f, + -0.459473f, -0.389648f, -0.161865f, 2.234375f, -0.181274f, 1.708008f, + 1.041016f, 0.774414f, -3.167969f, -2.246094f, 0.492188f, -1.341797f, + -1.295898f, -0.833008f, -2.148438f, -1.526367f, 0.506348f, 1.138672f, + -0.630859f, 0.585449f, -0.844727f, -1.111328f, 1.645508f, 0.793945f, + -0.478027f, 0.333008f, 1.714844f, -1.489258f, -0.479004f, 2.351562f, + -0.772461f, 0.032227f, -1.973633f, 0.800293f, 0.002441f, -0.521484f, + 0.337891f, -2.947266f, 1.101562f, -0.508789f, 1.232422f, 0.967773f, + 0.059082f, -0.575195f, -1.116211f, -0.760254f, 0.621582f, -1.933594f, + 1.401367f, -3.828125f, 2.292969f, 1.061523f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array has_add_out_values = {true, false}; + std::array skip_dims = {2, 4}; + + constexpr int channels_last = 1; + for (const int skip_dim : skip_dims) { + for (const bool has_add_out : has_add_out_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 8); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + if (skip_dim == 2) { + test.AddInput("skip", {B, C}, ToFloat16(skip_data)); + } else { + test.AddInput("skip", {B, 1, 1, C}, ToFloat16(skip_data)); + } + // no bias + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + + if (has_add_out) { + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + } + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_group_norm.py b/onnxruntime/test/python/transformers/test_group_norm.py new file mode 100644 index 0000000000000..bf295a65c8b53 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_group_norm.py @@ -0,0 +1,541 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import statistics +from dataclasses import dataclass +from enum import Enum +from time import perf_counter +from typing import Optional, Tuple + +import numpy +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + +torch.manual_seed(0) + + +class GroupNormOpType(Enum): + GROUP_NORM = 1 + SKIP_GROUP_NORM = 2 + + +@dataclass +class GroupNormConfig: + batch_size: int + height: int + width: int + channels: int + epsilon: float = 1e-5 + num_groups: int = 32 + activation: bool = False + channels_last: bool = True + fp16: bool = False + + op_type: GroupNormOpType = GroupNormOpType.GROUP_NORM + has_bias: bool = False + has_add_out: bool = False + broadcast_skip: int = 0 # 2 for (N, C), 4 for (N, 1, 1, C) + + def get_skip_symbolic_shape(self): + skip_shape = {0: ["N", "H", "W", "C"], 2: ["N", "C"], 4: ["N", 1, 1, "C"]} + return skip_shape[self.broadcast_skip] + + def get_skip_shape(self): + skip_shape = { + 0: [self.batch_size, self.height, self.width, self.channels], + 2: [self.batch_size, self.channels], + 4: [self.batch_size, 1, 1, self.channels], + } + return skip_shape[self.broadcast_skip] + + def broadcast(self, skip: torch.Tensor): + if self.broadcast_skip == 2: + return skip.reshape(self.batch_size, 1, 1, self.channels) + + return skip + + @staticmethod + def create( + b: int, + h: int, + w: int, + c: int, + fp16: bool = False, + activation: bool = False, + template: int = 0, + num_groups: int = 32, + ): + if template == 0: + return GroupNormConfig( + b, h, w, c, fp16=fp16, activation=activation, op_type=GroupNormOpType.GROUP_NORM, num_groups=num_groups + ) + + if template == 1: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 2: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=2, + num_groups=num_groups, + ) + + if template == 3: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=False, + broadcast_skip=4, + num_groups=num_groups, + ) + + if template == 4: # No bias + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 5: # No bias, no add_out + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=0, + num_groups=num_groups, + ) + + return None + + +def create_group_norm_graph(config: GroupNormConfig) -> bytes: + inputs = ["input", "gamma", "beta"] + outputs = ["output"] + op_type = "GroupNorm" + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + op_type = "SkipGroupNorm" + inputs = [*inputs, "skip"] + if config.has_bias: + inputs = [*inputs, "bias"] + if config.has_add_out: + outputs = [*outputs, "add_out"] + + nodes = [ + helper.make_node( + op_type, + inputs, + outputs, + op_type + "_0", + activation=int(config.activation), + channels_last=int(config.channels_last), + epsilon=config.epsilon, + groups=config.num_groups, + domain="com.microsoft", + ), + ] + + float_type = TensorProto.FLOAT16 if config.fp16 else TensorProto.FLOAT + + input_shapes = [ + helper.make_tensor_value_info("input", float_type, ["N", "H", "W", "C"]), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["C"]), + ] + output_shapes = [ + helper.make_tensor_value_info("output", float_type, ["N", "H", "W", "C"]), + ] + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + input_shapes = [ + *input_shapes, + helper.make_tensor_value_info("skip", float_type, config.get_skip_symbolic_shape()), + ] + if config.has_bias: + input_shapes = [*input_shapes, helper.make_tensor_value_info("bias", float_type, ["C"])] + if config.has_add_out: + output_shapes = [*output_shapes, helper.make_tensor_value_info("add_out", float_type, ["N", "H", "W", "C"])] + + graph = helper.make_graph( + nodes, + "Group_Norm_Graph", + input_shapes, + output_shapes, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def group_norm_ort( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, + measure_latency=False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]: + onnx_model_str = create_group_norm_graph(config) + ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"]) + + session = CudaSession(ort_session, device=torch.device("cuda:0")) + + io_shape = { + "input": [config.batch_size, config.height, config.width, config.channels], + "gamma": [config.channels], + "beta": [config.channels], + "output": [config.batch_size, config.height, config.width, config.channels], + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + io_shape["skip"] = config.get_skip_shape() + if config.has_bias: + io_shape["bias"] = [config.channels] + if config.has_add_out: + io_shape["add_out"] = [config.batch_size, config.height, config.width, config.channels] + + session.allocate_buffers(io_shape) + + ort_inputs = { + "input": src, + "gamma": gamma, + "beta": beta, + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + ort_inputs["skip"] = skip + if config.has_bias: + ort_inputs["bias"] = bias + + ort_outputs = session.infer(ort_inputs) + output = ort_outputs["output"] + + add_out = ( + ort_outputs["add_out"] if config.op_type == GroupNormOpType.SKIP_GROUP_NORM and config.has_add_out else None + ) + + if measure_latency: + latency_list = [] + for _ in range(10000): + start_time = perf_counter() + session.infer(ort_inputs) + end_time = perf_counter() + latency_list.append(end_time - start_time) + average_latency = statistics.mean(latency_list) + return output, add_out, average_latency + + return output, add_out, None + + +def group_norm_torch( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + add_out = src + + if skip is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + config.broadcast(skip) + + if bias is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + bias.reshape(1, 1, 1, bias.shape[0]) + + x = add_out + if config.channels_last: + x = add_out.clone().permute(0, 3, 1, 2) # from NHWC to NCHW + + weight = gamma.to(x.dtype) + bias = beta.to(x.dtype) + output = torch.nn.functional.group_norm(x, config.num_groups, weight=weight, bias=bias, eps=config.epsilon) + + if config.activation: + torch.nn.functional.silu(output, inplace=True) + + if config.channels_last: + output = output.permute(0, 2, 3, 1) # from NCHW to NHWC + + return output, add_out + + +def print_tensor(name, tensor): + # Print in the format that could be directly added to unit tests in C++. + torch.set_printoptions(precision=6, sci_mode=False, linewidth=100, profile="full", threshold=1000) + print(name) + if tensor is not None: + print("shape", tensor.shape) + text = str(tensor.clone().flatten()) + print(text.replace("[", "[\n").replace("]", ",\n]").replace(",", "f,")) + else: + print(tensor) + + +def run_parity(config, measure_latency=True, verbose=False): + float_type = torch.float16 if config.fp16 else torch.float32 + + input_tensor = torch.randn( + config.batch_size, + config.height, + config.width, + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + gamma = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + beta = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + skip = None + bias = None + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + skip = torch.randn( + *config.get_skip_shape(), + device="cuda", + dtype=float_type, + requires_grad=False, + ) + if config.has_bias: + bias = torch.randn( + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + if verbose: + print(config) + print_tensor("input", input_tensor) + print_tensor("gamma", gamma) + print_tensor("beta", beta) + print_tensor("skip", skip) + print_tensor("bias", bias) + + out_ort, ort_add_out, latency = group_norm_ort( + input_tensor, gamma, beta, skip, bias, config, measure_latency=measure_latency + ) + + if verbose: + print_tensor("out_ort", out_ort) + print_tensor("ort_add_out", ort_add_out) + + torch_out, torch_add_out = group_norm_torch(input_tensor, gamma, beta, skip, bias, config) + + if verbose: + print_tensor("torch_out", torch_out) + print_tensor("torch_add_out", torch_add_out) + + average_diff = numpy.mean(numpy.abs(out_ort.detach().cpu().numpy() - torch_out.detach().cpu().numpy())) + + is_close = numpy.allclose( + out_ort.detach().cpu().numpy(), + torch_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + + is_add_out_close = ( + numpy.allclose( + ort_add_out.detach().cpu().numpy(), + torch_add_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + if ort_add_out is not None + else "" + ) + + # Compare results + print( + config.op_type.name, + " B:", + config.batch_size, + " H:", + config.height, + " W:", + config.width, + " C:", + config.channels, + " G:", + config.num_groups, + " activation:", + int(config.activation), + " channels_last:", + int(config.channels_last), + " fp16:", + int(config.fp16), + f" Latency(μs): {int(latency * 1e6)}" if isinstance(latency, float) else "", + " AvgDiff:", + average_diff, + " Pass:", + is_close, + is_add_out_close, + ) + + +def get_latent_height_width(): + default_size = [(512, 512), (768, 768), (1024, 1024)] + small_img_size = [(512, 768), (768, 512)] + xl_img_size = [ + (1152, 896), + (896, 1152), + (1216, 832), + (832, 1216), + (1344, 768), + (768, 1344), + (1536, 640), + (640, 1536), + ] + return [(int(h / 8), int(w / 8)) for (h, w) in default_size + small_img_size + xl_img_size] + + +def get_channels(): + return [128, 256, 512, 1024, 2048, 320, 640, 960, 1920, 2560, 384, 768, 1536, 3072, 1152, 2304] + + +def run_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm with Silu Activation for ", "fp16" if fp16 else "fp32") + for b in [2]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, activation=True, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_no_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm without Activation for ", "fp16" if fp16 else "fp32") + for b in [1, 2, 4]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_all_groups(template: int, fp16, measure_latency=False): + group_sizes = [1, 2, 4, 8, 16, 32] + print("Test GroupNorm for different group sizes:", group_sizes) + for group_size in group_sizes: + for h, w in get_latent_height_width()[:3]: + for c in get_channels()[:2]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=group_size, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_odd_channels(template: int, fp16, measure_latency=False): + # Test some random number of channels that can be divisible by 2 * num_groups + for h, w in get_latent_height_width(): + for c in [448, 704, 832, 1664, 2240, 2688, 2880, 3008]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_small_inputs(template: int, fp16): + config = GroupNormConfig.create(2, 2, 2, 16, fp16=fp16, activation=False, num_groups=4, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=False, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=True, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + +def run_performance(fp16): + # Run perf test to tune parameters for given number of channels. + for h, w in get_latent_height_width()[:3]: + for c in get_channels(): + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=0) + run_parity(config, measure_latency=True) + + +def run_all(template: int): + for fp16 in [True, False]: + run_small_inputs(template, fp16) + run_odd_channels(template, fp16) + run_all_groups(template, fp16) + run_activation(template, fp16) + run_no_activation(template, fp16) + + +def run_not_implemented(): + # Expect failure. Check whether the error message is expected. + try: + config = GroupNormConfig(1, 2, 2, 513, num_groups=3) + run_parity(config) + except RuntimeError as e: + assert "GroupNorm in CUDA does not support the input: n=1 h=2 w=2 c=513 groups=3" in str(e) + + +def main(): + run_performance(True) + + run_not_implemented() + + for template in range(6): + run_all(template) + + +if __name__ == "__main__": + main() From 20f2dd8b6ba1c461f9a8d90a578178eab1ff20f7 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Tue, 31 Oct 2023 14:58:21 -0700 Subject: [PATCH 11/21] use onnx rel-1.15.0, update cgman, cmake/external and requirement hash (#18177) --- cgmanifests/generated/cgmanifest.json | 12 +----------- cmake/deps.txt | 2 +- cmake/external/onnx | 2 +- .../azure-pipelines/templates/download-deps.yml | 4 ++-- .../x64/python/cpu/scripts/requirements.txt | 2 +- .../linux/docker/scripts/manylinux/requirements.txt | 2 +- .../github/linux/docker/scripts/requirements.txt | 2 +- 7 files changed, 8 insertions(+), 18 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f9501253661a2..6b0e3659bd234 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -26,7 +26,7 @@ "component": { "type": "git", "git": { - "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40", + "commitHash": "b86cc54efce19530fb953e4b21f57e6b3888534c", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "git submodule at cmake/external/onnx" @@ -192,16 +192,6 @@ "comments": "mp11" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "onnx" - } - }, { "component": { "type": "git", diff --git a/cmake/deps.txt b/cmake/deps.txt index 631d326e2ba5b..aeb7c05080abb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa diff --git a/cmake/external/onnx b/cmake/external/onnx index 6a20ba82b439e..b86cc54efce19 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 6a20ba82b439ea1fd650da4d389e96b60a1dd828 +Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 1373381e4c83e..0f6310724e9a1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.97 + version: 1.0.104 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.97 + version: 1.0.104 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index 5341ae062d332..680b12602910e 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index b2893286803b0..8ef1fd4522973 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 5d48a93b09c90..5673bddfe058a 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -5,7 +5,7 @@ mypy pytest setuptools>=68.2.2 wheel>=0.35.1 -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx argparse sympy==1.12 flatbuffers From ed41a2836c7963f4c46a073ea7bc29f971e06618 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Tue, 31 Oct 2023 22:48:32 +0000 Subject: [PATCH 12/21] Fix cast removal bug (#17953) The `RemoveDuplicateCastTransformer` fairly naively removed Cast nodes from the graph without considering precision loss when using the same `TypeGroup`. For instance, F64 -> F32 -> F64 would be optimised out of the graph. I also noticed that signedness was not accounted for, which is not covered by any existing issue but is a problem. For example doing int -> unsigned int -> int produces very different values for negative inputs and so should not be optimised out One could argue that we shouldn't be performing such cast elimination at all (at least not in this transformer). The original scope might be well restricted to only eliminating unnecessary casts from the `InsertCastTransformer` and no others. ### Motivation and Context This should fix https://github.com/microsoft/onnxruntime/issues/17565, ttps://github.com/microsoft/onnxruntime/issues/9915 and https://github.com/microsoft/onnxruntime/issues/8787. --- .../core/optimizer/insert_cast_transformer.cc | 86 +++++++++++++++---- .../framework/insert_cast_transformer_test.cc | 65 ++++++++++++++ 2 files changed, 133 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7c087ec77d9fe..959fcd6efdc3c 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -32,7 +32,7 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, int64_t to_type, onnxruntime::ProviderType providerType) { // insert cast op to cast input - std::string node_name = graph.GenerateNodeName("InsertedCast_" + old_arg->Name()); + std::string node_name = graph.GenerateNodeName("InsertedPrecisionFreeCast_" + old_arg->Name()); auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type); @@ -235,7 +235,8 @@ enum TypeGroup { Unknown = -1, Bool = 0, Integer = 1, - Float = 2, + Unsigned = 2, + Float = 3, }; TypeGroup GetTypeGroup(DataType type) { @@ -243,11 +244,14 @@ TypeGroup GetTypeGroup(DataType type) { return Bool; } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") { return Integer; } + if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Unsigned; + } + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { return Float; } @@ -255,6 +259,22 @@ TypeGroup GetTypeGroup(DataType type) { return Unknown; } +int BitLength(DataType type) { + if (*type == "tensor(bool)") { + return 1; + } else if (*type == "tensor(uint8)" || *type == "tensor(int8)") { + return 8; + } else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") { + return 16; + } else if (*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") { + return 32; + } else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") { + return 64; + } else { + return -1; + } +} + /** Transformer to remove duplicate Cast nodes. */ class RemoveDuplicateCastTransformer : public GraphTransformer { public: @@ -262,6 +282,48 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: + static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) { + // This is not a complete cast optimisation pass, and is more conservative than it could be. + // For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass. + + // The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer. + // Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support. + auto src_type_group = GetTypeGroup(src_type); + auto dst_type_group = GetTypeGroup(dst_type); + if (Unknown == src_type_group || Unknown == dst_type_group) { + return true; + } + + // Do not remove any signed -> unsigned cast. + if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) { + return true; + } + + // Do not remove any floating point -> non floating point cast. + if (Float == src_type_group && Float != dst_type_group) { + return true; + } + + auto src_bit_length = BitLength(src_type); + auto dst_bit_length = BitLength(dst_type); + + // unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer. + if (Unsigned == src_type_group && Integer == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + // integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise. + if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + return true; + } + + return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); + } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { auto output_args = graph.GetOutputs(); InlinedHashSet graph_outputs; @@ -293,17 +355,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { // - for each consumer cast node, it meets above condition for this optimization. auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - TypeGroup src_type_group = GetTypeGroup(src_type); - TypeGroup dst_type_group = GetTypeGroup(dst_type); - if (src_type_group == Unknown || dst_type_group == Unknown) { - continue; - } - - bool loss_precision_cast = false; - if (src_type_group > dst_type_group) { - loss_precision_cast = true; - } + bool loss_precision_cast = UnsafeCast(src_type, dst_type, node); size_t num_children = node.GetOutputEdgesCount(); bool inconsistent_casts = false; @@ -312,10 +365,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (output_node.OpType() == "Cast") { auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); - TypeGroup src_type_group1 = GetTypeGroup(src_type1); - TypeGroup dst_type_group1 = GetTypeGroup(dst_type1); - if (src_type_group1 == Unknown || dst_type_group1 == Unknown || - (loss_precision_cast && dst_type_group1 > src_type_group1)) { + if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) { inconsistent_casts = true; break; } diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index c38baee39216b..1804c09043c7b 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -4,6 +4,7 @@ #include "core/framework/allocator.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/graph/model.h" +#include "core/graph/node_attr_utils.h" #include "gtest/gtest.h" #include "test_utils.h" #include "test/test_environment.h" @@ -110,6 +111,70 @@ TEST(TransformerTest, InsertCastAllCPUTest) { } } +TEST(TransformerTest, CastRemovalDoesNotLowerPrecisionTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_float_32; + tensor_float_32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + TypeProto tensor_float_64; + tensor_float_64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_DOUBLE); + onnxruntime::NodeArg n1_def("N1", &tensor_float_64), + n2_def("N2", &tensor_float_32), + n3_def("N3", &tensor_float_64); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))}}; + + graph.AddNode("node1", "Cast", "F64 to F32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "F32 to F64 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting f64 -> f32 -> f64 we should not be optimising away the cast since there is a loss of precision. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + +TEST(TransformerTest, CastRemovalDoesNotRemoveSignednessTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_uint32; + tensor_uint32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_UINT32); + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + onnxruntime::NodeArg n1_def("N1", &tensor_int32), + n2_def("N2", &tensor_uint32), + n3_def("N3", &tensor_int32); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_UINT32))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32))}}; + + graph.AddNode("node1", "Cast", "I32 to UI32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "UI32 to I32 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting i32 -> ui32 -> i32 we should not be optimising away the cast since applying the casts produces a very different result. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + // test that when there are 3 Cast ops in a row we remove the correct ones TEST(TransformerTest, ThreeInARowRemoval) { auto model_uri = MODEL_FOLDER ORT_TSTR("triple-cast.onnx"); From 62c7894ffe15efb7d43d891a326c2cbdcfbb529d Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 1 Nov 2023 09:25:48 +1000 Subject: [PATCH 13/21] Add mobile CIs to list run by script for external PRs. (#18094) ### Description Add the mobile CIs to the list so we check external PRs don't break those. ### Motivation and Context Recent external PR was found to break iOS CI after checkin --- tools/python/run_CIs_for_external_pr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index dcc6a92d84ef2..7a77839c4a4e7 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,10 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # not currently required, but running ensures we're hitting all mobile platforms + "Android CI Pipeline", + "iOS CI Pipeline", + "ONNX Runtime React Native CI Pipeline", ] # remove pipelines that have already run successfully From 2b95e74fa113ec168a79974987b2c6b98cecf700 Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Tue, 31 Oct 2023 16:50:27 -0700 Subject: [PATCH 14/21] Versioning for custom op (#18088) Allow custom ops to have versions. --------- Co-authored-by: Randy Shuai --- .../core/session/onnxruntime_c_api.h | 4 ++ .../core/session/onnxruntime_cxx_api.h | 13 ++++ .../core/session/onnxruntime_lite_custom_op.h | 59 ++++++++++++++----- onnxruntime/core/session/custom_ops.cc | 22 ++++++- onnxruntime/test/shared_lib/test_inference.cc | 16 +++++ .../testdata/custom_op_library/cpu/cpu_ops.cc | 37 +++++++----- .../test/testdata/fuse_select_filter.onnx | 5 +- .../testdata/fuse_select_filter_opset_8.onnx | 29 +++++++++ 8 files changed, 148 insertions(+), 37 deletions(-) create mode 100644 onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4a63018f870a6..613c1ac93cf1b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4605,6 +4605,10 @@ struct OrtCustomOp { OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 467eb31ee2c8e..92c25d8688b66 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2228,6 +2228,8 @@ struct ShapeInferContext { using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -2280,6 +2282,14 @@ struct CustomOpBase : OrtCustomOp { } SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider @@ -2348,6 +2358,9 @@ struct CustomOpBase : OrtCustomOp { protected: // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index b12221e56b79f..443710884743a 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp { PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) OrtLiteCustomOp(const char* op_name, - const char* execution_provider) : op_name_(op_name), - execution_provider_(execution_provider) { + const char* execution_provider, + int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -837,6 +840,16 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtCustomOp::KernelCompute = {}; OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; } const std::string op_name_; @@ -844,6 +857,9 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -873,9 +889,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFn compute_fn, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_(compute_fn), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -911,9 +929,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFnReturnStatus compute_fn_return_status, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_return_status_(compute_fn_return_status), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -985,8 +1005,9 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { }; OrtLiteCustomStruct(const char* op_name, - const char* execution_provider) : OrtLiteCustomOp(op_name, - execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { @@ -1049,25 +1070,31 @@ template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, void (*custom_compute_fn)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, Status (*custom_compute_fn_v2)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, - const char* execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomStruct; - return std::make_unique(op_name, execution_provider).release(); + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); } } // namespace Custom diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 041250adc3fc0..b827c28f129b1 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -25,6 +25,7 @@ #if !defined(ORT_MINIMAL_BUILD) static constexpr uint32_t min_ort_version_with_optional_io_support = 8; static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; +static constexpr uint32_t min_ort_version_with_custom_version = 17; #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -698,8 +699,19 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust KernelDefBuilder def_builder; def_builder.SetName(op->GetName(op)) - .SetDomain(domain) - .SinceVersion(1); + .SetDomain(domain); + + if (op->version >= min_ort_version_with_custom_version) { + if (op->GetStartVersion && op->GetEndVersion) { + def_builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op)); + } else if (op->GetStartVersion) { + def_builder.SinceVersion(op->GetStartVersion(op)); + } else { + def_builder.SinceVersion(1); + } + } else { + def_builder.SinceVersion(1); + } // GetInputMemoryType was introduced in ver 13. This check allows custom ops compiled using older versions // to work with newer versions (> 12) of the ORT binary. @@ -820,7 +832,11 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types"); } schema.SetDomain(domain); - schema.SinceVersion(1); + if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) { + schema.SinceVersion(op->GetStartVersion(op)); + } else { + schema.SinceVersion(1); + } schema.AllowUncheckedAttributes(); if (op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn) { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ba282193c5ca6..33d50f90333cf 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3323,6 +3323,22 @@ TEST(LiteCustomOpTest, CustomFunc) { ASSERT_TRUE(floats_output[1] == 16); } +TEST(LiteCustomOpTest, CustomFuncOpsetMismatch) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(0); +#if defined(_WIN32) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); +#elif defined(__APPLE__) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); +#else + session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); +#endif + + EXPECT_THROW(Ort::Session(*ort_env, TSTR("testdata/fuse_select_filter_opset_8.onnx"), session_options), std::exception); +} + struct Merge { Merge(const OrtApi* ort_api, const OrtKernelInfo* info) { int64_t reverse; diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index ad99b675c7d20..85edfa0e59f1d 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -94,23 +94,28 @@ void Select(const Ort::Custom::Span& indices_in, } } -void Filter(const Ort::Custom::Tensor& floats_in, - Ort::Custom::Tensor& floats_out) { - const float* in = floats_in.Data(); - auto in_len = floats_in.NumberOfElement(); +struct Filter { + Filter(const OrtApi*, const OrtKernelInfo*) {} + Ort::Status Compute(const Ort::Custom::Tensor& floats_in, + Ort::Custom::Tensor& floats_out) { + const float* in = floats_in.Data(); + auto in_len = floats_in.NumberOfElement(); + + std::vector filter_floats; + for (int64_t i = 0; i < in_len; ++i) { + if (in[i] > 1.f) { + filter_floats.push_back(in[i]); + } + } - std::vector filter_floats; - for (int64_t i = 0; i < in_len; ++i) { - if (in[i] > 1.f) { - filter_floats.push_back(in[i]); + float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); + for (size_t j = 0; j < filter_floats.size(); ++j) { + out[j] = filter_floats[j]; } - } - float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); - for (size_t j = 0; j < filter_floats.size(); ++j) { - out[j] = filter_floats[j]; + return Ort::Status{nullptr}; } -} +}; void Box(const Ort::Custom::Tensor* float_in_1, const Ort::Custom::Tensor* float_in_2, @@ -293,9 +298,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; static const std::unique_ptr c_MulTopOpFloat{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; static const std::unique_ptr c_MulTopOpInt32{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; - static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse)}; + static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse, {}, 10, 12)}; static const std::unique_ptr c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)}; - static const std::unique_ptr c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; + static const std::unique_ptr c_Filter{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", 15, 17)}; static const std::unique_ptr c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)}; static const std::unique_ptr c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic)}; static const std::unique_ptr c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined)}; @@ -314,7 +319,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_MulTopOpInt32.get()); domain.Add(c_Fuse.get()); domain.Add(c_Select.get()); - domain.Add(c_Fill.get()); + domain.Add(c_Filter.get()); domain.Add(c_Box.get()); domain.Add(c_CopyTensorArrayAllVariadic.get()); domain.Add(c_CopyTensorArrayCombined.get()); diff --git a/onnxruntime/test/testdata/fuse_select_filter.onnx b/onnxruntime/test/testdata/fuse_select_filter.onnx index 15d7dd64788d3..0b881228edb9d 100644 --- a/onnxruntime/test/testdata/fuse_select_filter.onnx +++ b/onnxruntime/test/testdata/fuse_select_filter.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä P vector_1 vector_2 @@ -25,4 +25,5 @@ N ÿÿÿÿÿÿÿÿÿb& vector_filtered  - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file diff --git a/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx new file mode 100644 index 0000000000000..3ea27767eb9f5 --- /dev/null +++ b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx @@ -0,0 +1,29 @@ + :Ä +P +vector_1 +vector_2 +alpha vector_fused fuse_node"Fuse* + fuse_algo :v2 +4 +indicesindices_selected select_node"Select:v2 +N + vector_fused +indices_selectedvector_gathered gather_node"GatherElements +; +vector_gatheredvector_filtered filter_node"Filter:v2graphZ +vector_1 + + ÿÿÿÿÿÿÿÿÿZ +vector_2 + + ÿÿÿÿÿÿÿÿÿZ +alpha + + ÿÿÿÿÿÿÿÿÿZ +indices + + ÿÿÿÿÿÿÿÿÿb& +vector_filtered + + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file From d1b85f5fb4fff6fc674e50e2053039c7ded4969e Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Tue, 31 Oct 2023 17:53:52 -0700 Subject: [PATCH 15/21] Reduce LLaMA memory usage (#18181) ### Description This PR reduces the memory usage when exporting and benchmarking LLaMA. ### Motivation and Context - Exporting: The PyTorch model is deleted from memory after a successful export instead of deleting it from memory after exporting + converting the ONNX model to the desired precision. - Benchmarking: In the ONNX model with GroupQueryAttention, the KV cache inputs use the same GPU memory for both the prompt and token generation benchmarks. --- .../transformers/models/llama/benchmark.py | 104 +++---- .../models/llama/convert_to_onnx.py | 2 +- .../transformers/models/llama/llama_inputs.py | 271 +++++++++++++----- .../transformers/models/llama/llama_parity.py | 57 ++-- 4 files changed, 248 insertions(+), 186 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index a721979eb0bcb..245ff3dfe7f9d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,9 +11,8 @@ import onnx import psutil import torch -from benchmark_helper import setup_logger from llama_inputs import ( - convert_inputs_for_ort, + add_io_bindings, get_merged_sample_with_past_kv_inputs, get_msft_sample_inputs, get_sample_inputs, @@ -25,7 +24,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory +from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger logger = logging.getLogger(__name__) @@ -48,9 +47,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): init_inputs, iter_inputs = None, None # For past_present_share_buffer: - # Set max_seq_len to 2048 for Hugging Face model since that is the default value - # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported - max_seq_len = 2048 + # Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2) + # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value + # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_seq_len = ( + 2048 + if args.benchmark_type == "ort-msft" + else 16384 + if "codellama" in temp_name + else 4096 + if "llama2" in temp_name + else 2048 + ) if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( @@ -95,7 +104,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( @@ -104,7 +115,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) @@ -116,7 +129,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="ort", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( @@ -125,26 +140,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, - use_fp16=args.use_fp16, - return_dict=True, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + engine="ort", + return_dict=True, ) elif args.benchmark_type == "ort-msft": @@ -156,6 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=0, seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, split_kv=split_kv, ) @@ -164,26 +164,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=args.sequence_length, seq_len=1, - use_fp16=args.use_fp16, - split_kv=split_kv, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + split_kv=split_kv, ) else: @@ -449,7 +432,7 @@ def get_logits(inputs): def run_ort_inference(args, init_inputs, iter_inputs, model): - def prepare_ort_inputs(inputs): + def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Check that all model inputs will be provided model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) user_inputs = set(inputs.keys()) @@ -467,29 +450,13 @@ def prepare_ort_inputs(inputs): # Add IO bindings for non-CPU execution providers if args.device != "cpu": - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.past_present_share_buffer: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.past_present_share_buffer and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) - + io_binding, kv_cache_ortvalues = add_io_bindings( + model, inputs, args.device, int(args.device_id), kv_cache_ortvalues + ) setattr(args, "io_binding", io_binding) # noqa: B010 - return io_binding + return io_binding, kv_cache_ortvalues - return inputs + return inputs, kv_cache_ortvalues def with_io_binding(io_binding): # Inference pass with IO binding @@ -501,9 +468,10 @@ def without_io_binding(inputs): return outputs generate_fn = with_io_binding if args.device != "cpu" else without_io_binding + kv_cache_ortvalues = {} if args.profile: - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt") # Turn profiling off to stop appending to log file @@ -513,7 +481,7 @@ def without_io_binding(inputs): # Re-initialize model for new log file instead of appending to old log file model = get_model(args) - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") # Turn profiling off to stop appending to log @@ -524,12 +492,12 @@ def without_io_binding(inputs): # ORT evaluations logger.info("\nEvaluating `model(inputs)` step to get past_key_values") - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_init_inputs) measure_fn(args, generate_fn, ort_init_inputs) logger.info("\nEvaluating `model(inputs)` step with past_key_values") - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_iter_inputs) measure_fn(args, generate_fn, ort_iter_inputs) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 69603fd3ed488..3f05be53c6729 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -716,6 +716,7 @@ def main(): run_torchscript_separate_export(args, l_config, llama) else: run_torchscript_merged_export(args, l_config, llama) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check # Set model paths to store FP32 optimized model decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") @@ -811,7 +812,6 @@ def main(): logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") remove_existing_model(fp_path) - del llama # Delete LLaMA model from memory since it will be loaded again during parity check logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 2652e9f0ca64e..f7a1b05249abf 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -4,7 +4,7 @@ import torch from transformers import LlamaConfig -from onnxruntime import OrtValue +from onnxruntime import InferenceSession, OrtValue # Get position_ids from attention_mask @@ -12,22 +12,36 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if use_past_kv: + # Shape: (batch_size, 1) position_ids = position_ids[:, -1].unsqueeze(-1) + + # Shape: (batch_size, sequence_length) return position_ids # Inputs for first pass to get initial past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, sequence_length) +# position_ids: (batch_size, sequence_length) def get_sample_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False + config: LlamaConfig, + device: torch.device, + batch_size: int, + seq_len: int, + engine: str = "pt", + return_dict: bool = False, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64) - # position_ids is of shape (batch_size, seq_len) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64) position_ids = get_position_ids(attention_mask, use_past_kv=False) + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + if not return_dict: + # For export return (input_ids, attention_mask, position_ids) inputs = { @@ -39,85 +53,192 @@ def get_sample_inputs( # Inputs for subsequent passes with past_key_values +# input_ids: (batch_size, 1) +# attention_mask: (batch_size, past_sequence_length + 1) +# position_ids: (batch_size, 1) +# past_key: (batch_size, num_heads, past_sequence_length, head_size) +# past_value: (batch_size, num_heads, past_sequence_length, head_size) def get_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, ): - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64) - attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs # Inputs for all passes with past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, past_sequence_length + sequence_length) +# position_ids: (batch_size, sequence_length) +# past_key: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length +# past_value: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length def get_merged_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, past_seq_len: int, + max_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + + if use_fp16: # If model has GQA + del inputs["attention_mask"] + inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) + inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) + + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs -# Create past_key_values -def get_sample_past_kv_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool +# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx +def get_msft_sample_inputs( + config: LlamaConfig, + batch_size: int, + past_seq_len: int, + seq_len: int, + max_seq_len: int, + use_fp16: bool, + split_kv: bool, ): + np_dtype = np.float16 if use_fp16 else np.float32 + head_size = config.hidden_size // config.num_attention_heads + + if not split_kv: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), + "k_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "v_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "pos": np.array(past_seq_len, dtype=np.int64), + } + else: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( + np.int32 + ), + "pos": np.array(past_seq_len, dtype=np.int64), + } + for i in range(config.num_hidden_layers): + ort_inputs.update( + { + f"k_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + f"v_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + } + ) + + if use_fp16: # If model has GQA + del ort_inputs["attn_mask"] + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) + + return ort_inputs + + +# Create past_key_values +# Each is of shape (batch_size, num_heads, past_sequence_length, head_size) +def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool): num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), ) for _ in range(config.num_hidden_layers) ] return past_kv -# Convert list of past_kv to dict of past_key and past_value -def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool): +# Convert list of past_key_values to dict of past_key and past_value +def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]): past_kv = {} - np_dtype = np.float16 if use_fp16 else np.float32 for i, (past_k, past_v) in enumerate(past_key_values): - past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype) - past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype) + past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() + past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy() return past_kv @@ -136,7 +257,7 @@ def convert_inputs_for_ort( if isinstance(v, np.ndarray): ort_inputs[k] = v elif k == "past_key_values": - ort_inputs.update(flatten_past_kv_inputs(v, use_fp16)) + ort_inputs.update(flatten_past_kv_inputs(v)) elif k == "attention_mask" and use_fp16 and use_buffer_share: # Skip because FP16 model has GroupQueryAttention, uses buffer sharing, # and GQA supports a causal mask by default @@ -146,59 +267,55 @@ def convert_inputs_for_ort( else: ort_inputs[k] = v.detach().cpu().numpy() - # Enable past-present-share-buffer by using device memory directly + # Reshape kv caches if using past-present-share-buffer if use_buffer_share and device != "" and device != "cpu" and device_id > -1: - for k, v in ort_inputs.items(): - new_v = v - # Allocate new buffers with max_sequence_length for GQA - if "cache" in k or "past_key_values" in k: - # Copy v (BxSxPxH) into new_v (BxSxMxH) - batch_size, num_heads, _, head_size = v.shape - new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) - new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v - ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id) + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs -# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx -def get_msft_sample_inputs( - config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool -): - np_dtype = np.float16 if use_fp16 else np.float32 - head_size = config.hidden_size // config.num_attention_heads - max_seq_len = 2048 +def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int): + for k, v in ort_inputs.items(): + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = new_v + return ort_inputs - if not split_kv: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), - "k_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "v_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "pos": np.array(past_seq_len, dtype=np.int64), - } - else: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( - np.int32 - ), - "pos": np.array(past_seq_len, dtype=np.int64), - } - for i in range(config.num_hidden_layers): - ort_inputs.update( - { - f"k_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - f"v_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - } - ) - return ort_inputs +# Add IO bindings for execution providers +def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict): + use_fp16 = False + io_binding = model.io_binding() + + for k, v in ort_inputs.items(): + # Detect if model is in FP16 + if v.dtype == np.float16: + use_fp16 = True + + # Bind OrtValue inputs to device + if use_fp16 and ("cache" in k or "past_key_values" in k): + if k not in kv_cache_ortvalues: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + kv_cache_ortvalues[k] = v_device + else: + kv_cache_ortvalues[k].update_inplace(v) + io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k]) + else: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + + for output in model.get_outputs(): + name = output.name + if use_fp16 and ("out" in name or "present" in name): + # Bind present KV cache outputs to past KV cache inputs in order to buffer share + input_name = name.replace("out", "cache").replace("present", "past_key_values") + io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name]) + else: + io_binding.bind_output(name, device_type=device, device_id=device_id) + + return io_binding, kv_cache_ortvalues diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 4353d0606803d..c1c5d3c412f2a 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -8,6 +8,7 @@ import torch from benchmark_helper import setup_logger from llama_inputs import ( + add_io_bindings, convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, get_sample_inputs, @@ -22,22 +23,24 @@ def get_sequence_lengths(args: argparse.Namespace): past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) - max_sequence_length = 2048 + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 return past_sequence_length, curr_sequence_length, max_sequence_length def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity batch_size = 2 - past_sequence_length, sequence_length, _ = get_sequence_lengths(args) + past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) if args.merged: inputs = get_merged_sample_with_past_kv_inputs( config, args.device, batch_size, - sequence_length, - past_sequence_length, + seq_len=sequence_length, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, use_fp16=args.use_fp16, return_dict=True, ) @@ -51,31 +54,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): return inputs -def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict): - # Add IO bindings for non-CPU execution providers - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.use_fp16: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.use_fp16 and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id)) - - return io_binding - - -def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): +def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict): inputs = get_inputs(args, config) # Run inference with PyTorch @@ -111,7 +90,9 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": - io_binding = add_io_bindings(args, ort_model, inputs) + io_binding, kv_cache_ortvalues = add_io_bindings( + ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues + ) io_binding.synchronize_inputs() start_time = time.time() @@ -131,17 +112,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama logger.info(f"ONNX Runtime took {end_time - start_time} s") # Compare PyTorch and ONNX Runtime accuracy - tol = ( - 2e1 - if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path - else 1e-3 - if args.precision == "fp32" - else 5e-1 - ) + tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1 parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") if not parity: logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}") + return kv_cache_ortvalues def get_args(argv: List[str]): @@ -250,16 +226,17 @@ def main(argv: List[str] = []): # noqa: B006 use_cache=True, ).to(args.device) + kv_cache_ortvalues = {} if not args.merged: - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) else: # Verify prompt generation in merged model (decoder_model.onnx) args.use_past_kv = False - verify_parity(args, config, llama) + kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues) # Verify token generation in merged model (decoder_with_past_model.onnx) args.use_past_kv = True - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) if __name__ == "__main__": From c181159783d1245adb9bb1af18a469ad7d89df45 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 1 Nov 2023 11:30:32 +0800 Subject: [PATCH 16/21] [WebNN EP] Restore to use deviceType enum (#18154) The Chromium implementation will support `MLDeviceType` enum to align with spec. CL: https://chromium-review.googlesource.com/c/chromium/src/+/4986939 --- .../core/providers/webnn/webnn_execution_provider.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 26c739e9a1ce1..02a3d16b5b64f 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -26,11 +26,7 @@ WebNNExecutionProvider::WebNNExecutionProvider( ORT_THROW("Failed to get ml from navigator."); } emscripten::val context_options = emscripten::val::object(); - // Currently WebNN implementation in Chromium temporarily reuses the MLContextOptions - // defined in Model Loader API, which uses MLDevicePreference instead of MLDeviceType - // defined in WebNN. Because there's an ongoing spec discussion to simplify this API at - // https://github.com/webmachinelearning/webnn/issues/302. - context_options.set("devicePreference", emscripten::val(webnn_device_flags)); + context_options.set("deviceType", emscripten::val(webnn_device_flags)); // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; From 819b5a3eba85cca9276c9d763c814eb45067b280 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Tue, 31 Oct 2023 21:05:42 -0700 Subject: [PATCH 17/21] Split KV on MHA and Attention ops (#18007) ### Description Implement Split KV optimization for FlashAttention in MHA and Attention operators. ### Motivation and Context Can help further accelerate these ops. --- .../contrib_ops/cpu/bert/attention_common.h | 3 ++- .../contrib_ops/cuda/bert/attention.cc | 22 +++++++++++++++ .../contrib_ops/cuda/bert/attention_impl.cu | 4 ++- .../contrib_ops/cuda/bert/attention_impl.h | 5 ++++ .../cuda/bert/flash_attention/flash_api.cc | 27 ++++++++++++++++--- .../cuda/bert/flash_attention/flash_api.h | 6 ++--- .../flash_fwd_launch_template.h | 12 ++------- .../cuda/bert/group_query_attention.cc | 22 ++++++--------- .../cuda/bert/multihead_attention.cc | 22 +++++++++++++++ 9 files changed, 90 insertions(+), 33 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 5184dd99309b1..0fd8790e0d29d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -55,6 +55,7 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; + int num_splits; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -95,9 +96,9 @@ struct GroupQueryAttentionParameters { int head_size; int kv_hidden_size; int kv_num_heads; + int num_splits; // number of splits for splitkv bool is_unidirectional; // causal float scale; - int num_splits; // number of splits for splitkv AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 0dc7de0e9e519..bf6431cf1afb2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -135,8 +135,24 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif if (!use_flash_attention) { @@ -279,6 +295,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index b4a4ae208ceb1..eb9e6d5c62467 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -316,7 +316,9 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + true)); DUMP_TENSOR("flash attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d0a5fb51a25d6..3e78978c3cc43 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -88,6 +88,11 @@ struct AttentionData { T* v = nullptr; T* scratch = nullptr; AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index ff7a22d253a5b..89a27c4d2b0d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -140,11 +140,10 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, - int max_splits, bool new_kv, bool is_sm8x) { + int max_splits) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)) - : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)); - const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n; + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. const int num_m_blocks = (seqlen_q + 64 - 1) / 64; @@ -190,6 +189,26 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea return 1; } +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 0a0328edb0059..58f4304251872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -31,6 +31,7 @@ #if USE_FLASH_ATTENTION #include "core/providers/cuda/cuda_common.h" +#include namespace onnxruntime { namespace flash { @@ -99,10 +100,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, ); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); -size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q); -size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded); -int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x); +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 784335a124c75..82dfa59b8f8e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -123,17 +123,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; constexpr int kBlockM = 64; // Fixed for all head dimensions - if (!is_sm8x) { // A100, H100 - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); } template diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 65d19d4473872..67d750aeac11a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -116,22 +116,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { size_t out_accum_bytes = 0; size_t seqlens_k_bytes = 0; if (use_flash_attention) { + // softmax buffer softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); - // split kv buffers - parameters.num_splits = onnxruntime::flash::num_splits_heuristic( + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, - parameters.head_size, device_prop.multiProcessorCount, 128, false, - device_prop.major == 8 && device_prop.minor > 0); - if (parameters.num_splits > 1) { - // softmax_lse_accum buffer - softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); - // out_accum buffer - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(parameters.head_size, 32); - out_accum_bytes = onnxruntime::flash::get_out_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded); - } + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; // seqlens_k buffer if (past_key != nullptr) { seqlens_k_bytes = sizeof(int) * parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index e3f53ca6a63cb..ebd66d8c6528e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -153,8 +153,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif bool use_fused_cross_attention = !use_flash_attention && @@ -291,6 +307,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); From 69f029797d24e44c6854d9c68231bef174e627e4 Mon Sep 17 00:00:00 2001 From: weischan-quic <138087696+weischan-quic@users.noreply.github.com> Date: Wed, 1 Nov 2023 14:04:42 +0800 Subject: [PATCH 18/21] [QNN EP] Fix Batch Normalization Op Builder (#17981) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description There is a gap between onnx’s definition of batch normalization and QNN’s. According to the formula: onnx: `(X - input_mean) / sqrt(input_var + epsilon) * scale + B` QNN: `X * weight + bias` We can then deduce that: `weight = scale / sqrt(var + epsilon)` `bias = B – (mean * scale / sqrt(var + epsilon))` We must calculate the weight and bias, and their quantization parameters for QNN in QNN EP. Therefore, `scale`, `B`, `input_mean`, and `input_var` must be static (`initializer`). Implementation: Firstly, dequantize `scale`, `B`, `input_mean`, and `input_var` to floating point. Second, calculate `weight` and `bias`, and their quantization parameters. Finally, quantize `weight` and `bias`, and add them into `TensorWrapper` ### Motivation and Context Fix QnnHTPBackendTests.BatchNorm1D and QnnHTPBackendTests.BatchNorm2D failures --- .../opbuilder/batch_norm_op_builder.cc | 589 +++++++++++++++++- .../test/providers/qnn/batch_norm_htp_test.cc | 18 +- 2 files changed, 583 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index ccbc1acaa2f9e..3e17fb157b160 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -1,16 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include + #include "core/providers/common.h" +#include "core/util/qmath.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class BatchNormOpBuilder : public BaseOpBuilder { @@ -18,9 +22,446 @@ class BatchNormOpBuilder : public BaseOpBuilder { BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BatchNormOpBuilder); + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + std::pair CheckMinMax(float rmin, float rmax) const { + // Ensure a minimum range of 0.0001 (required by QNN) + rmax = std::max(rmax, rmin + 0.0001f); + + // Both QNN and ORT require the range to include 0.0f + rmin = std::min(rmin, 0.0f); + rmax = std::max(rmax, 0.0f); + + return std::make_pair(rmin, rmax); + } + + template + Status GetQminQmax(const Qnn_DataType_t qnn_data_type, + T& qmin, + T& qmax) const { + if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + Status GetQuantParams(float rmin, + float rmax, + const Qnn_DataType_t qnn_data_type, + float& scale, + int& zero_point) const { + std::tie(rmin, rmax) = CheckMinMax(rmin, rmax); + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + + scale = (rmax - rmin) / (qmax - qmin); + const float initial_zero_point = qmin - (rmin / scale); + zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); + // To match QNN quantization definition + zero_point = 0 - zero_point; + return Status::OK(); + } + + inline Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const uint8_t* raw_ptr, + double& value, + int& offset) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int8_t); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int16_t); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int32_t); + break; + } + case QNN_DATATYPE_INT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int64_t); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint8_t); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint16_t); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint32_t); + break; + } + case QNN_DATATYPE_UINT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint64_t); + break; + } + case QNN_DATATYPE_FLOAT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(float); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status AssertUnpackedTensorSize(const Qnn_DataType_t qnn_data_type, + const uint32_t channel, + const size_t raw_ptr_length) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_FLOAT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(float)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status ConvertToRawOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const std::vector& double_tensor, + std::vector& raw_tensor) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(int8_t)); + int8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(int16_t)); + int16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(int32_t)); + int32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(int64_t)); + int64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(uint8_t)); + uint8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(uint16_t)); + uint16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(uint32_t)); + uint32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(uint64_t)); + uint64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_FLOAT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(float)); + float* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UFIXED_POINT_32: + case QNN_DATATYPE_UFIXED_POINT_16: + case QNN_DATATYPE_UFIXED_POINT_8: + case QNN_DATATYPE_SFIXED_POINT_32: + case QNN_DATATYPE_SFIXED_POINT_16: + case QNN_DATATYPE_SFIXED_POINT_8: + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline double Dequantize(const OnnxInputInfo& info, + const double quant_value) const { + auto offset = static_cast(info.quant_param.scaleOffsetEncoding.offset); + auto scale = static_cast(info.quant_param.scaleOffsetEncoding.scale); + return (quant_value + offset) * scale; + } + + template + inline T Saturate(const T qmax, + const T qmin, + const T quant_value) const { + if (quant_value > qmax) { + return qmax; + } else if (quant_value < qmin) { + return qmin; + } else { + return quant_value; + } + } + + inline Status Quantize(const double double_value, + const float scale, + const int zero_point, + const Qnn_DataType_t qnn_data_type, + int& quant_value) const { + int qmin = 0; + int qmax = 255; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + quant_value = Saturate(qmax, qmin, static_cast(std::round((double_value / scale) - zero_point))); + return Status::OK(); + } + + Status PreprocessMean(const OnnxInputInfo& mean_info, + const bool is_npu_backend, + const uint8_t* mean_raw_ptr, + const size_t mean_raw_ptr_length, + std::vector& mean_out) const { + // tensor length (channel) + uint32_t channel = mean_info.shape[0]; + mean_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double mean_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset)); + mean_out[i] = (is_npu_backend) ? Dequantize(mean_info, mean_value) : mean_value; + } + return Status::OK(); + } + + Status PreprocessStd(const OnnxInputInfo& var_info, + const bool is_npu_backend, + const uint8_t* var_raw_ptr, + const size_t var_raw_ptr_length, + const float epsilon, + std::vector& std_out) const { + // tensor length (channel) + uint32_t channel = var_info.shape[0]; + std_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double var_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset)); + std_out[i] = (is_npu_backend) ? Dequantize(var_info, var_value) : var_value; + std_out[i] = std::sqrt(std_out[i] + static_cast(epsilon)); + } + return Status::OK(); + } + + Status PreprocessScale(const OnnxInputInfo& scale_info, + const bool is_npu_backend, + const uint8_t* scale_raw_ptr, + const size_t scale_raw_ptr_length, + const std::vector& std_double_tensor, + double& rmax, + double& rmin, + std::vector& scale_out) const { + // tensor length (channel) + uint32_t channel = scale_info.shape[0]; + scale_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double scale_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset)); + scale_out[i] = (is_npu_backend) ? Dequantize(scale_info, scale_value) : scale_value; + scale_out[i] = scale_out[i] / std_double_tensor[i]; + rmax = std::max(rmax, scale_out[i]); + rmin = std::min(rmin, scale_out[i]); + } + return Status::OK(); + } + + Status PreprocessBias(const OnnxInputInfo& bias_info, + const bool is_npu_backend, + const uint8_t* bias_raw_ptr, + const size_t bias_raw_ptr_length, + const std::vector& scale_double_tensor, + const std::vector& mean_double_tensor, + double& rmax, + double& rmin, + std::vector& bias_out) const { + // tensor length (channel) + uint32_t channel = bias_info.shape[0]; + bias_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double bias_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset)); + bias_out[i] = (is_npu_backend) ? Dequantize(bias_info, bias_value) : bias_value; + bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]); + rmax = std::max(rmax, bias_out[i]); + rmin = std::min(rmin, bias_out[i]); + } + return Status::OK(); + } + + Status Postprocess(const OnnxInputInfo& info, + const bool is_npu_backend, + const std::vector& double_tensor, + const double rmax, + const double rmin, + Qnn_QuantizeParams_t& quant_param, + std::vector& raw_tensor) const { + if (is_npu_backend) { + raw_tensor.resize(double_tensor.size()); + float scale = 0.0f; + int zero_point = 0; + ORT_RETURN_IF_ERROR(GetQuantParams(static_cast(rmin), + static_cast(rmax), + info.qnn_data_type, + scale, + zero_point)); + quant_param = QNN_QUANTIZE_PARAMS_INIT; + utils::InitializeQuantizeParam(quant_param, true, scale, zero_point); + for (size_t i = 0; i < double_tensor.size(); ++i) { + // onnx only supports 8 bits quantization + int quant_value_int = 0; + ORT_RETURN_IF_ERROR(Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int)); + if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + raw_tensor[i] = static_cast(quant_value_int); + } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + int8_t quant_value = static_cast(quant_value_int); + raw_tensor[i] = *reinterpret_cast(&quant_value); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type); + } + } + } else { + ORT_RETURN_IF_ERROR(ConvertToRawOnQnnDataType(info.qnn_data_type, double_tensor, raw_tensor)); + } + return Status::OK(); + } }; // BatchNorm is sensitive with data layout, no special validation so far @@ -34,11 +475,6 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, // Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } else { - NodeAttrHelper node_helper(node_unit); - const float default_epsilon = 1e-05f; - const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. - ORT_RETURN_IF(abs(epsilon - default_epsilon) > default_epsilon, "QNN BatchNorm doesn't support epsilon."); - const auto& inputs = node_unit.Inputs(); ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); @@ -56,11 +492,16 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, std::vector scale_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[1].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic scale."); ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels, "QNN BatchNorm input 1 (scale) must have 1D shape [channel]."); std::vector bias_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[2].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic bias."); + ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels, "QNN BatchNorm input 2 (bias) must have 1D shape [channel]."); @@ -68,13 +509,15 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean)."); ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels, "QNN BatchNorm input 3 (mean) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic mean."); std::vector var_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var)."); ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels, "QNN BatchNorm input 4 (var) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic var."); ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output."); } @@ -82,6 +525,134 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + ORT_UNUSED_PARAMETER(logger); + + const auto& inputs = node_unit.Inputs(); + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + // + // Input 0 + // + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + + // + // Input 1: scale + // Input 2: bias + // QNN only accept 3 input. We need to first combine mean and variance into scale and bias. + // + { + const std::string& scale_name = inputs[1].node_arg.Name(); + const std::string& bias_name = inputs[2].node_arg.Name(); + OnnxInputInfo var_info = {}; + OnnxInputInfo mean_info = {}; + OnnxInputInfo scale_info = {}; + OnnxInputInfo bias_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], scale_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], bias_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[3], mean_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[4], var_info)); + + // scale, bias, mean, and var must be initializers + ORT_RETURN_IF_NOT(scale_info.is_initializer, "scale must be initializers"); + ORT_RETURN_IF_NOT(bias_info.is_initializer, "bias must be initializers"); + ORT_RETURN_IF_NOT(mean_info.is_initializer, "mean must be initializers"); + ORT_RETURN_IF_NOT(var_info.is_initializer, "var must be initializers"); + + std::vector scale_unpacked_tensor; + std::vector bias_unpacked_tensor; + std::vector var_unpacked_tensor; + std::vector mean_unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*scale_info.initializer_tensor, scale_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*bias_info.initializer_tensor, bias_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*mean_info.initializer_tensor, mean_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*var_info.initializer_tensor, var_unpacked_tensor)); + + std::vector mean_double_tensor; + std::vector std_double_tensor; + std::vector scale_double_tensor; + std::vector bias_double_tensor; + + NodeAttrHelper node_helper(node_unit); + const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. + + double scale_rmax = std::numeric_limits::min(); + double scale_rmin = std::numeric_limits::max(); + double bias_rmax = std::numeric_limits::min(); + double bias_rmin = std::numeric_limits::max(); + + // Calculate and convert new scale, new bias, mean and std to double array (may be dequantized) + ORT_RETURN_IF_ERROR(PreprocessMean(mean_info, + is_npu_backend, + mean_unpacked_tensor.data(), + mean_unpacked_tensor.size(), + mean_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessStd(var_info, + is_npu_backend, + var_unpacked_tensor.data(), + var_unpacked_tensor.size(), + epsilon, + std_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessScale(scale_info, + is_npu_backend, + scale_unpacked_tensor.data(), + scale_unpacked_tensor.size(), + std_double_tensor, + scale_rmax, + scale_rmin, + scale_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessBias(bias_info, + is_npu_backend, + bias_unpacked_tensor.data(), + bias_unpacked_tensor.size(), + scale_double_tensor, + mean_double_tensor, + bias_rmax, + bias_rmin, + bias_double_tensor)); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) { + std::vector scale_raw_tensor; + Qnn_QuantizeParams_t scale_quant_param = scale_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(scale_info, + is_npu_backend, + scale_double_tensor, + scale_rmax, + scale_rmin, + scale_quant_param, + scale_raw_tensor)); + Qnn_TensorType_t scale_tensor_type = GetInputTensorType(qnn_model_wrapper, scale_name); + QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type, scale_quant_param, + std::move(scale_info.shape), std::move(scale_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(scale_name); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(bias_name)) { + std::vector bias_raw_tensor; + Qnn_QuantizeParams_t bias_quant_param = bias_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(bias_info, + is_npu_backend, + bias_double_tensor, + bias_rmax, + bias_rmin, + bias_quant_param, + bias_raw_tensor)); + Qnn_TensorType_t bias_tensor_type = GetInputTensorType(qnn_model_wrapper, bias_name); + QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type, bias_quant_param, + std::move(bias_info.shape), std::move(bias_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(bias_name); + } + + return Status::OK(); +} + void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc index 9b65ca7bda3e2..b4e8f5390787c 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc @@ -175,13 +175,7 @@ static void RunBatchNormQDQTest(const TestInputDef& input_def, // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 3. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 4. -// Output quant params: scale=0.019084848463535309, zero_point=9. -// Expected val: 1.7755576372146606 -// QNN QDQ val: 2.9963212013244629 (err 1.2207635641098022) -// CPU QDQ val: 0.82064849138259888 (err 0.95490914583206177) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { +TEST_F(QnnHTPBackendTests, BatchNorm1D) { constexpr int64_t num_channels = 2; RunBatchNormQDQTest(TestInputDef({1, num_channels, 3}, false, {-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data @@ -193,13 +187,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 4. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 14. -// Output quant params: scale=0.023071292787790298, zero_point=19. -// Expected val: 2.8554618358612061 -// QNN QDQ val: 5.3294687271118164 (err 2.4740068912506104) -// CPU QDQ val: 1.6611330509185791 (err 1.194328784942627) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm2D) { +TEST_F(QnnHTPBackendTests, BatchNorm2D) { constexpr int64_t num_channels = 2; std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; @@ -226,4 +214,4 @@ TEST_F(QnnHTPBackendTests, BatchNorm3D) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif From d87216bcb13c8a3937a74b1cd2160aeb7d9cffb7 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Wed, 1 Nov 2023 08:39:39 -0700 Subject: [PATCH 19/21] Openvino ep ort 23.1 (#17911) ### Description Integration to OpenVINO 2023.1 ### Motivation and Context - Alignment with latest OpenVINO Version. - Device name change from VPUX to NPU and Remove from supported list until official public support is available. --------- Co-authored-by: Sahar Fatima Co-authored-by: Saurabh Kale Co-authored-by: Suryaprakash Shanmugam Co-authored-by: sfatimar --- cmake/CMakeLists.txt | 18 - docs/python/ReadMeOV.rst | 2 - .../core/session/onnxruntime_c_api.h | 4 +- .../providers/openvino/backend_manager.cc | 24 +- .../core/providers/openvino/backend_manager.h | 13 +- .../core/providers/openvino/backend_utils.cc | 11 +- .../core/providers/openvino/backend_utils.h | 12 +- .../openvino/backends/backend_factory.cc | 2 +- .../openvino/backends/basic_backend.cc | 77 ++-- .../openvino/backends/basic_backend.h | 11 +- .../core/providers/openvino/contexts.h | 7 +- .../openvino/openvino_execution_provider.cc | 24 +- .../openvino/openvino_execution_provider.h | 50 ++- .../openvino/openvino_provider_factory.cc | 55 ++- .../core/providers/openvino/ov_interface.cc | 10 +- .../core/providers/openvino/ov_interface.h | 13 +- .../openvino/ov_versions/capabilities.h | 2 + .../openvino/ov_versions/capability.cc | 17 +- .../openvino/ov_versions/data_ops.cc | 419 +++++++++++------- .../providers/openvino/ov_versions/data_ops.h | 13 +- .../providers/openvino/ov_versions/utils.cc | 12 +- .../providers/openvino/ov_versions/utils.h | 21 +- .../core/session/provider_bridge_ort.cc | 2 +- .../python/onnxruntime_pybind_state.cc | 4 +- .../python/onnxruntime_pybind_state_common.h | 8 +- .../test/perftest/command_args_parser.cc | 4 +- onnxruntime/test/perftest/ort_test_session.cc | 11 +- .../test/providers/cpu/nn/lp_norm_op_test.cc | 4 +- .../test/providers/cpu/rnn/rnn_op_test.cc | 4 +- .../providers/cpu/tensor/compress_op.test.cc | 2 +- .../providers/cpu/tensor/unsqueeze_op_test.cc | 2 +- .../test/python/onnx_backend_test_series.py | 3 + .../onnx_backend_test_series_filters.jsonc | 4 + tools/ci_build/build.py | 13 +- .../nuget/generate_nuspec_for_native_nuget.py | 44 +- 35 files changed, 564 insertions(+), 358 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index f81a268d38dff..94181448fd21c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1282,14 +1282,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP16) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_U8) - add_definitions(-DOPENVINO_CONFIG_VPUX_U8=1) - endif() - if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1310,16 +1302,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP32_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP32=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_FP16_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst index f12c01d278dca..6ef16e1378139 100644 --- a/docs/python/ReadMeOV.rst +++ b/docs/python/ReadMeOV.rst @@ -7,7 +7,6 @@ OpenVINOâ„¢ Execution Provider for ONNX Runtime accelerates inference across man - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs Installation ------------ @@ -22,7 +21,6 @@ This package supports: - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs ``pip3 install onnxruntime-openvino`` diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 613c1ac93cf1b..729a302f3dd0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -611,7 +611,7 @@ typedef struct OrtMIGraphXProviderOptions { typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus OrtOpenVINOProviderOptions() : device_type{}, - enable_vpu_fast_compile{}, + enable_npu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, @@ -624,7 +624,7 @@ typedef struct OrtOpenVINOProviderOptions { * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" */ const char* device_type; - unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; size_t num_of_threads; ///< 0 = Use default number of threads const char* cache_dir; // path is set to empty by default diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 78467b646b195..7e4c0dc8d7267 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -2,9 +2,7 @@ // Licensed under the MIT License #include -#include -#include -#include +#include #include "core/providers/shared_library/provider_api.h" #include "contexts.h" @@ -18,7 +16,8 @@ namespace openvino_ep { static std::unique_ptr g_global_context; GlobalContext& BackendManager::GetGlobalContext() { - // This is not thread safe to call for the first time, but it is first called on the main thread by the constructor so it is safe. + // This is not thread safe to call for the first time, + // but it is first called on the main thread by the constructor so it is safe. if (!g_global_context) g_global_context = std::make_unique(); return *g_global_context; @@ -88,7 +87,9 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, << "Backend created for graph " << subgraph_context_.subgraph_name; } } else { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. Initializing backend for graph " << subgraph_context_.subgraph_name; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. " + << "Initializing backend for graph " + << subgraph_context_.subgraph_name; subgraph_context_.has_dynamic_input_shape = false; try { @@ -104,7 +105,7 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { bool has_batched_inputs = true; - for (int i = 0; i < (int)subgraph_context_.input_indexes.size(); i++) { + for (int i = 0; i < static_cast(subgraph_context_.input_indexes.size()); i++) { auto& input = model_proto.graph().input(subgraph_context_.input_indexes[i]); // Batch-process only raw image inputs (NCHW or NHWC layouts) @@ -215,7 +216,10 @@ BackendManager::ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_pr auto graph_proto = model_copy->mutable_graph(); for (size_t i = 0, limit = input_shapes.size(); i < limit; i++) { - auto g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + auto g_in_shape = graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->clear_dim(); const auto& shape = input_shapes[i]; for (size_t dim = 0, end = shape.size(); dim < end; dim++) { @@ -234,7 +238,11 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p auto graph_proto = model_copy->mutable_graph(); for (int i = 0; i < graph_proto->input_size(); i++) { - ONNX_NAMESPACE::TensorShapeProto* g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + ONNX_NAMESPACE::TensorShapeProto* g_in_shape = + graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->mutable_dim(0)->clear_dim_value(); g_in_shape->mutable_dim(0)->set_dim_value(1); } diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index c247ab60d3a6f..a177324b23f7d 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -3,6 +3,11 @@ #pragma once +#include +#include +#include +#include + #include "ov_interface.h" #include "contexts.h" #include "ibackend.h" @@ -13,7 +18,9 @@ namespace openvino_ep { // Singleton class that manages all the backends class BackendManager { public: - BackendManager(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger); + BackendManager(const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger); void Compute(OrtKernelContext* context); void ShutdownBackendManager(); static GlobalContext& GetGlobalContext(); @@ -21,7 +28,9 @@ class BackendManager { private: std::unique_ptr GetModelProtoFromFusedNode( - const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const; + const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger) const; bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index d49968cdb7f3d..d47c91dd46622 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -1,9 +1,7 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License -#include -#include -#include +#include #include #include @@ -58,7 +56,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext try { auto cnn_network = global_context.ie_core.ReadModel(model); if ((subgraph_context.precision == "FP16") && - (global_context.device_type.find("VPUX") == std::string::npos)) { + (global_context.device_type.find("NPU") == std::string::npos)) { // FP16 transformations ov::pass::ConvertFP32ToFP16 pass_obj; pass_obj.run_on_model(cnn_network); @@ -88,7 +86,8 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext size_t index = results.size() - 1; for (auto it = results.rbegin(); it != results.rend(); ++it) { - if (auto const_node = std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { + if (auto const_node = + std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { const_outputs_map[(*it)->get_friendly_name()] = const_node; results.erase(results.begin() + index); } @@ -254,7 +253,7 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName) { - long long totalTime = 0; + int64_t totalTime = 0; // Print performance counts stream << std::endl << "performance counts:" << std::endl diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index de78a150fe2dd..82b0351e87da5 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -4,9 +4,15 @@ #pragma once #define ORT_API_MANUAL_INIT +#include +#include +#include +#include +#include +#include + #include "core/session/onnxruntime_cxx_api.h" #include "contexts.h" -#include #include "ov_interface.h" #ifdef _WIN32 #include @@ -57,7 +63,9 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, size_t batch_slice_idx); std::shared_ptr -CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, const SubGraphContext& subgraph_context, +CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, + const GlobalContext& global_context, + const SubGraphContext& subgraph_context, std::map>& const_outputs_map); void printPerformanceCounts(const std::vector& performanceMap, diff --git a/onnxruntime/core/providers/openvino/backends/backend_factory.cc b/onnxruntime/core/providers/openvino/backends/backend_factory.cc index c339f24e7022f..c586dd8b38af9 100644 --- a/onnxruntime/core/providers/openvino/backends/backend_factory.cc +++ b/onnxruntime/core/providers/openvino/backends/backend_factory.cc @@ -16,7 +16,7 @@ BackendFactory::MakeBackend(const ONNX_NAMESPACE::ModelProto& model_proto, const SubGraphContext& subgraph_context) { std::string type = global_context.device_type; if (type == "CPU" || type.find("GPU") != std::string::npos || - type.find("VPUX") != std::string::npos || + type.find("NPU") != std::string::npos || type.find("HETERO") != std::string::npos || type.find("MULTI") != std::string::npos || type.find("AUTO") != std::string::npos) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index f9517d7942664..09e1322ff59fb 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -6,10 +6,10 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" -// #include #include "basic_backend.h" #include "../backend_manager.h" @@ -57,33 +57,39 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, cl_context ctx = static_cast(global_context_.context); remote_context_ = new ov::intel_gpu::ocl::ClContext(global_context_.ie_core.Get(), ctx); ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") { const std::string model = model_proto.SerializeAsString(); - exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + model, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; #endif #endif } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } } catch (const char* msg) { @@ -127,10 +133,10 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { } #endif #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - if (global_context_.device_type.find("VPUX") != std::string::npos) { + if (global_context_.device_type.find("NPU") != std::string::npos) { std::pair device_property; - device_property = std::make_pair("VPU_COMPILER_TYPE", "MLIR"); - device_config.emplace(ov::device::properties("VPUX", device_property)); + device_property = std::make_pair("NPU_COMPILER_TYPE", "DRIVER"); + device_config.emplace(ov::device::properties("NPU", device_property)); } #endif } @@ -152,12 +158,12 @@ void BasicBackend::EnableCaching() { } void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) { - if (global_context_.enable_opencl_throttling == true && global_context_.device_type.find("GPU") != std::string::npos) { + if (global_context_.enable_opencl_throttling == true && + global_context_.device_type.find("GPU") != std::string::npos) { LOGS_DEFAULT(INFO) << log_tag << "Enabled OpenCL queue throttling for GPU device"; std::pair device_property; device_property = std::make_pair("PLUGIN_THROTTLE", "1"); device_config.emplace(ov::device::properties("GPU_CONFIG_KEY", device_property)); - // device_config[GPU_CONFIG_KEY(PLUGIN_THROTTLE)] = "1"; } } @@ -187,7 +193,9 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } size_t batch_slice_idx = 0; if (subgraph_context_.has_dynamic_input_shape && @@ -197,6 +205,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); auto tensor_shape = tensor_info.GetShape(); auto tensor_size = tensor_shape.size(); + const char* tensor_data = tensor.GetTensorData(); auto tensor_iter = 0; ov::Shape input_tensor_shape = ov::Shape(tensor_size, 0); for (auto i = tensor_shape.begin(); i != tensor_shape.end(); ++i) { @@ -204,8 +213,16 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque tensor_iter += 1; } auto input = ie_cnn_network_->get_parameters().at(input_idx); - OVTensorPtr tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + OVTensorPtr tensor_ptr; + // avoid input copies on the CPU device + if (global_context_.device_type.find("CPU") != std::string::npos) { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape, + (void*)tensor_data); + } else { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); + FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + } + try { infer_request->SetTensor(input_name, tensor_ptr); } catch (const char* msg) { @@ -251,7 +268,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } input_idx++; // Kernel Context Input Buffer @@ -264,9 +284,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create an Input Remote Blob auto input = ie_cnn_network_->get_parameters().at(0); - auto remote_blob = remote_context_->create_tensor(input->get_element_type(), input->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_blob); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_blob = remote_context_->create_tensor( + input->get_element_type(), input->get_shape(), *shared_buffer_const); + ov::Tensor tensor_remote = static_cast(remote_blob); + OVTensorPtr tensor_ptr = std::make_shared(tensor_remote); infer_request->SetTensor(input_name, tensor_ptr); } else { OVTensorPtr graph_input_blob; @@ -295,7 +316,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe } } if (!output_name_found) { - throw std::string(log_tag + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); + throw std::string( + log_tag + + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); } size_t batch_size = 1; @@ -307,9 +331,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create a shared Blob, set the Infer Request Output Blob auto output = ie_cnn_network_->get_results().at(0); - auto remote_tensor = remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_tensor); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_tensor = + remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); + ov::Tensor tensor_t = static_cast(remote_tensor); + OVTensorPtr tensor_ptr = std::make_shared(tensor_t); try { infer_request->SetTensor(output_name, tensor_ptr); } catch (const char* msg) { @@ -364,7 +389,8 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe throw(msg); } size_t batch_size = 1; - auto output_tensor = GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); + auto output_tensor = + GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); auto mem_info = output_tensor.GetTensorMemoryInfo(); if (mem_info.GetAllocatorName() == OpenVINO_GPU) { return; @@ -465,7 +491,8 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { #ifndef IO_BUFFER_ENABLED // Printing performance counts is disabled when IO_BUFFER_ENABLED if (openvino_ep::backend_utils::IsDebugEnabled()) { inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode - std::string& hw_target = (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; + std::string& hw_target = + (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; printPerformanceCounts(infer_request, std::cout, hw_target); } #endif diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 2f1d603640809..6eda641451a72 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -6,16 +6,17 @@ #include #define ORT_API_MANUAL_INIT -#include "core/session/onnxruntime_cxx_api.h" -#include "core/providers/openvino/contexts.h" -#include "core/providers/openvino/ibackend.h" -#include "core/providers/openvino/ov_interface.h" #include #include #include #include #include +#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/openvino/contexts.h" +#include "core/providers/openvino/ibackend.h" +#include "core/providers/openvino/ov_interface.h" + namespace onnxruntime { namespace openvino_ep { @@ -29,7 +30,7 @@ class BasicBackend : public IBackend { void Infer(OrtKernelContext* context) override; private: - bool ImportBlob(std::string hw_target, bool vpu_status); + bool ImportBlob(std::string hw_target, bool npu_status); void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&); bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index b61dcf8ca4922..29233e72c33b9 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -3,6 +3,9 @@ #pragma once +#include +#include +#include #include "ov_interface.h" namespace onnxruntime { @@ -12,7 +15,7 @@ namespace openvino_ep { struct GlobalContext { OVCore ie_core; bool is_wholly_supported_graph = false; - bool enable_vpu_fast_compile = false; + bool enable_npu_fast_compile = false; bool enable_opencl_throttling = false; bool enable_dynamic_shapes = false; size_t num_of_threads; @@ -34,7 +37,7 @@ struct GlobalContext { struct SubGraphContext { bool has_dynamic_input_shape = false; bool enable_batching = false; - bool set_vpu_config = false; + bool set_npu_config = false; bool is_constant = false; void* context = 0; std::string subgraph_name; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 990809926299e..a4c6b0f851c04 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -17,17 +17,18 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv openvino_ep::BackendManager::GetGlobalContext().device_type = info.device_type_; openvino_ep::BackendManager::GetGlobalContext().precision_str = info.precision_; - openvino_ep::BackendManager::GetGlobalContext().enable_vpu_fast_compile = info.enable_vpu_fast_compile_; + openvino_ep::BackendManager::GetGlobalContext().enable_npu_fast_compile = info.enable_npu_fast_compile_; openvino_ep::BackendManager::GetGlobalContext().cache_dir = info.cache_dir_; openvino_ep::BackendManager::GetGlobalContext().num_streams = info.num_streams_; openvino_ep::BackendManager::GetGlobalContext().context = info.context_; openvino_ep::BackendManager::GetGlobalContext().enable_opencl_throttling = info.enable_opencl_throttling_; openvino_ep::BackendManager::GetGlobalContext().enable_dynamic_shapes = info.enable_dynamic_shapes_; - if ((int)info.num_of_threads_ <= 0) { + if (static_cast(info.num_of_threads_) <= 0) { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = 8; - } else if ((int)info.num_of_threads_ > 8) { - std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; + } else if (static_cast(info.num_of_threads_) > 8) { + std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; ORT_THROW(err_msg); } else { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = info.num_of_threads_; @@ -56,7 +57,8 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv device_found = true; break; } - if (info.device_type_.find("VPUX") != std::string::npos && (info.precision_ == "FP16" || info.precision_ == "U8")) { + if ((info.device_type_.find("NPU") != std::string::npos) && + (info.precision_ == "FP16" || info.precision_ == "U8")) { device_found = true; break; } @@ -109,11 +111,14 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, openvino_ep::BackendManager::GetGlobalContext().onnx_model_name = graph_viewer.Name(); #ifdef _WIN32 std::wstring onnx_path = graph_viewer.ModelPath().ToPathString(); - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = std::string(onnx_path.begin(), onnx_path.end()); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + std::string(onnx_path.begin(), onnx_path.end()); #else - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = graph_viewer.ModelPath().ToPathString(); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + graph_viewer.ModelPath().ToPathString(); #endif - openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = graph_viewer.DomainToVersionMap().at(kOnnxDomain); + openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = + graph_viewer.DomainToVersionMap().at(kOnnxDomain); #if defined(OPENVINO_2022_1) openvino_ep::GetCapability obj(graph_viewer, @@ -151,7 +156,8 @@ common::Status OpenVINOExecutionProvider::Compile( openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true; - std::shared_ptr backend_manager = std::make_shared(fused_node, graph_body_viewer, *GetLogger()); + std::shared_ptr backend_manager = + std::make_shared(fused_node, graph_body_viewer, *GetLogger()); compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index a4fc09362fa23..3b56b54410e40 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -3,19 +3,28 @@ #pragma once -#include "backend_manager.h" #include #include #include +#include +#include +#include + +#include "backend_manager.h" namespace onnxruntime { static void print_build_options() { std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl; - std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority you want to build" << std::endl; - std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build "; - std::cout << "are ['CPU','GPU','VPUX']" << std::endl; - std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" << std::endl; + std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority " + << "you want to build" + << std::endl; + std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build " + << "are ['CPU','GPU']" + << std::endl; + std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. " + << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" + << std::endl; } static std::vector split(const std::string& s, char delim) { @@ -39,7 +48,7 @@ static std::vector parseDevices(const std::string& device_string) { print_build_options(); ORT_THROW("Invalid device string: " + device_string); } - std::vector dev_options = {"CPU", "GPU", "VPUX"}; + std::vector dev_options = {"CPU", "GPU"}; for (std::string dev : devices) { if (!std::count(dev_options.begin(), dev_options.end(), dev)) { print_build_options(); @@ -53,7 +62,7 @@ static std::vector parseDevices(const std::string& device_string) { struct OpenVINOExecutionProviderInfo { std::string device_type_; std::string precision_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -62,11 +71,18 @@ struct OpenVINOExecutionProviderInfo { bool enable_opencl_throttling_; bool enable_dynamic_shapes_; - explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_vpu_fast_compile, std::string dev_id, + explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_npu_fast_compile, std::string dev_id, size_t num_of_threads, std::string cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), device_id_(dev_id), num_of_threads_(num_of_threads), cache_dir_(cache_dir), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + device_id_(dev_id), + num_of_threads_(num_of_threads), + cache_dir_(cache_dir), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { if (dev_type == "") { LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" << "No runtime device selection option provided."; @@ -82,11 +98,11 @@ struct OpenVINOExecutionProviderInfo { #elif defined OPENVINO_CONFIG_GPU_FP16 device_type_ = "GPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_FP16 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_FP16 + device_type_ = "NPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_U8 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_U8 + device_type_ = "NPU"; precision_ = "U8"; #elif defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO #ifdef DEVICE_NAME @@ -126,11 +142,11 @@ struct OpenVINOExecutionProviderInfo { } else if (dev_type == "GPU.1_FP16") { device_type_ = "GPU.1"; precision_ = "FP16"; - } else if (dev_type == "VPUX_FP16") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_FP16") { + device_type_ = "NPU"; precision_ = "FP16"; - } else if (dev_type == "VPUX_U8") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_U8") { + device_type_ = "NPU"; precision_ = "U8"; } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0) { std::vector devices = parseDevices(dev_type); diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 95b39bcc05983..fbb89710c8008 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -8,11 +8,16 @@ namespace onnxruntime { struct OpenVINOProviderFactory : IExecutionProviderFactory { - OpenVINOProviderFactory(const char* device_type, bool enable_vpu_fast_compile, + OpenVINOProviderFactory(const char* device_type, bool enable_npu_fast_compile, const char* device_id, size_t num_of_threads, const char* cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), num_of_threads_(num_of_threads), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + num_of_threads_(num_of_threads), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { device_type_ = (device_type == nullptr) ? "" : device_type; device_id_ = (device_id == nullptr) ? "" : device_id; cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir; @@ -24,7 +29,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { private: std::string device_type_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -35,7 +40,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { - OpenVINOExecutionProviderInfo info(device_type_, enable_vpu_fast_compile_, device_id_, num_of_threads_, + OpenVINOExecutionProviderInfo info(device_type_, enable_npu_fast_compile_, device_id_, num_of_threads_, cache_dir_, num_streams_, context_, enable_opencl_throttling_, enable_dynamic_shapes_); return std::make_unique(info); @@ -59,17 +64,18 @@ struct OpenVINO_Provider : Provider { std::string device_type = ""; // [device_type]: Overrides the accelerator hardware type and precision // with these values at runtime. - bool enable_vpu_fast_compile = false; // [enable_vpu_fast_compile]: Fast-compile may be optionally enabled to - // speeds up the model's compilation to VPU device specific format. + bool enable_npu_fast_compile = false; // [enable_npu_fast_compile]: Fast-compile may be optionally enabled to + // speeds up the model's compilation to NPU device specific format. const char* device_id = ""; // [device_id]: Selects a particular hardware device for inference. - size_t num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of + int num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of // threads with this value at runtime. const char* cache_dir = ""; // [cache_dir]: specify the path to // dump and load the blobs for the model caching/kernel caching (GPU) // feature. If blob files are already present, it will be directly loaded. int num_streams = 1; // [num_streams]: Option that specifies the number of parallel inference // requests to be processed on a given `device_type`. Overrides the - // accelerator default value of number of streams with this value at runtime. + // accelerator default value of number of streams + // with this value at runtime. bool enable_opencl_throttling = false; // [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU // device (Reduces CPU Utilization when using GPU) bool enable_dynamic_shapes = false; // [enable_dynamic_shapes]: Enables Dynamic Shapes feature for CPU device) @@ -80,14 +86,15 @@ struct OpenVINO_Provider : Provider { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (!((ov_supported_device_types.find(device_type) != ov_supported_device_types.end()) || - (device_type.find("HETERO:") == 0) || (device_type.find("MULTI:") == 0) || (device_type.find("AUTO:") == 0))) { + (device_type.find("HETERO:") == 0) || + (device_type.find("MULTI:") == 0) || + (device_type.find("AUTO:") == 0))) { ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } @@ -97,30 +104,37 @@ struct OpenVINO_Provider : Provider { if (provider_options_map.find("cache_dir") != provider_options_map.end()) { cache_dir = provider_options_map.at("cache_dir").c_str(); } + if (provider_options_map.find("context") != provider_options_map.end()) { - context = (void*)provider_options_map.at("context").c_str(); + std::string str = provider_options_map.at("context"); + uint64_t number = std::strtoull(str.c_str(), nullptr, 16); + context = reinterpret_cast(number); } if (provider_options_map.find("num_of_threads") != provider_options_map.end()) { num_of_threads = std::stoi(provider_options_map.at("num_of_threads")); if (num_of_threads <= 0) { num_of_threads = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_threads' should be in the positive range.\n " + << "Executing with num_threads=1"; } } if (provider_options_map.find("num_streams") != provider_options_map.end()) { num_streams = std::stoi(provider_options_map.at("num_streams")); - if (num_streams <= 0 && num_streams > 8) { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_streams' should be in the range of 1-8 \n"); + if (num_streams <= 0) { + num_streams = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_streams' should be in the range of 1-8.\n " + << "Executing with num_streams=1"; } } std::string bool_flag = ""; - if (provider_options_map.find("enable_vpu_fast_compile") != provider_options_map.end()) { - bool_flag = provider_options_map.at("enable_vpu_fast_compile"); + if (provider_options_map.find("enable_npu_fast_compile") != provider_options_map.end()) { + bool_flag = provider_options_map.at("enable_npu_fast_compile"); if (bool_flag == "true" || bool_flag == "True") - enable_vpu_fast_compile = true; + enable_npu_fast_compile = true; else if (bool_flag == "false" || bool_flag == "False") - enable_vpu_fast_compile = false; + enable_npu_fast_compile = false; bool_flag = ""; } @@ -141,7 +155,7 @@ struct OpenVINO_Provider : Provider { enable_dynamic_shapes = false; } return std::make_shared(const_cast(device_type.c_str()), - enable_vpu_fast_compile, + enable_npu_fast_compile, device_id, num_of_threads, cache_dir, @@ -157,7 +171,6 @@ struct OpenVINO_Provider : Provider { void Shutdown() override { openvino_ep::BackendManager::ReleaseGlobalContext(); } - } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 3914488fc523b..d2ce378c97e02 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -29,7 +29,10 @@ std::shared_ptr OVCore::ReadModel(const std::string& model) const { } } -OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(ie_cnn_network, hw_target, device_config); @@ -43,7 +46,10 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std } #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) -OVExeNetwork OVCore::LoadNetwork(const std::string& model, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(const std::string& model, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(model, ov::Tensor(), hw_target, device_config); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index ed9583033ab34..935ac8f68411d 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -4,6 +4,7 @@ #pragma once #include +#include #if defined(OPENVINO_2022_1) || (OPENVINO_2022_2) || (OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) #define OV_API_20 @@ -43,9 +44,15 @@ class OVCore { public: std::shared_ptr ReadModel(const std::string& model_stream) const; - OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - OVExeNetwork LoadNetwork(const std::string& model_stream, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(const std::string& model_stream, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #endif void SetCache(std::string cache_dir_path); #ifdef IO_BUFFER_ENABLED @@ -62,7 +69,7 @@ class OVExeNetwork { ov::CompiledModel obj; public: - OVExeNetwork(ov::CompiledModel md) { obj = md; } + explicit OVExeNetwork(ov::CompiledModel md) { obj = md; } OVExeNetwork() { obj = ov::CompiledModel(); } ov::CompiledModel& Get() { return obj; } OVInferRequest CreateInferRequest(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h index b76d1cf534c2a..5bcf9d68cd94e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h +++ b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include "data_ops.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 171dd45c508cc..b030efa238209 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -24,7 +24,8 @@ namespace openvino_ep { // Constructor GetCapability::GetCapability(const GraphViewer& graph_viewer_param, std::string device_type_param, - const std::string version_param) : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { + const std::string version_param) + : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { if (version_param == "V_2022_1") { data_ops_ = new DataOps(graph_viewer_, V_2022_1, device_type_); } else if (version_param == "V_2022_2") { @@ -114,11 +115,11 @@ std::vector> GetCapability::Execute() { } openvino_ep::BackendManager::GetGlobalContext().is_wholly_supported_graph = true; - } else { // unsupported_nodes_idx.empty() - + } else { // unsupported_nodes_idx.empty() #if defined(OPENVINO_DISABLE_GRAPH_PARTITION) // disables graph partition at build time LOGS_DEFAULT(INFO) << "[OpenVINO-EP] DISABLE_GRAPH_PARTITION option is set"; - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, so making the full model fall back to default CPU Execution Provider"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, " + << "so making the full model fall back to default CPU Execution Provider"; return result; #endif @@ -159,7 +160,13 @@ std::vector> GetCapability::Execute() { std::vector cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs; - GetInputsOutputsOfCluster(graph_viewer_, this_cluster, ng_required_initializers, cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs); + GetInputsOutputsOfCluster(graph_viewer_, + this_cluster, + ng_required_initializers, + cluster_graph_inputs, + cluster_inputs, + const_inputs, + cluster_outputs); bool omit_subgraph = false; // Omitting zero dim subgraphs diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 70118c94f9ff8..a5a0faa3a8f24 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -2,11 +2,15 @@ // Licensed under the MIT License #include +#include +#include +#include +#include +#include + #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" #include "../backend_manager.h" -#include -#include #include "data_ops.h" #include "capabilities.h" #include "utils.h" @@ -72,269 +76,355 @@ std::set ops_supported_as_function = { std::vector supported_op_mode = { {"Abs", V_2020_4, {"CPU", "GPU"}}, - {"Abs", V_2023_0, {"VPUX"}}, + {"Abs", V_2023_0, {"NPU"}}, {"Acos", V_2020_4, {"CPU"}}, {"Acos", V_2022_1, {"GPU"}}, + {"Acos", V_2023_1, {"NPU"}}, {"Acosh", V_2020_4, {"CPU"}}, {"Acosh", V_2022_1, {"GPU"}}, + {"Acosh", V_2023_1, {"NPU"}}, {"Add", V_2020_4, {"CPU", "GPU"}}, - {"Add", V_2023_0, {"VPUX"}}, + {"Add", V_2023_0, {"NPU"}}, {"And", V_2020_4, {"CPU", "GPU"}}, + {"And", V_2023_1, {"NPU"}}, {"ArgMax", V_2020_4, {"CPU"}}, {"ArgMax", V_2021_1, {"GPU"}}, {"ArgMin", V_2020_4, {"CPU"}}, {"ArgMin", V_2022_1, {"GPU"}}, {"Asin", V_2020_4, {"CPU", "GPU"}}, + {"Asin", V_2023_1, {"NPU"}}, {"Asinh", V_2020_4, {"CPU", "GPU"}}, + {"Asinh", V_2023_1, {"NPU"}}, {"Atan", V_2020_4, {"CPU", "GPU"}}, + {"Atan", V_2023_1, {"NPU"}}, {"Atanh", V_2020_4, {"CPU"}}, {"Atanh", V_2022_1, {"GPU"}}, + {"Atanh", V_2023_1, {"NPU"}}, {"AveragePool", V_2020_4, {"CPU", "GPU"}}, - {"AveragePool", V_2023_0, {"VPUX"}}, + {"AveragePool", V_2023_0, {"NPU"}}, {"BatchNormalization", V_2020_4, {"CPU", "GPU"}}, - {"BatchNormalization", V_2023_0, {"VPUX"}}, + {"BatchNormalization", V_2023_0, {"NPU"}}, {"BitShift", V_2022_1, {"CPU"}}, + {"BitShift", V_2023_1, {"NPU"}}, {"Cast", V_2020_4, {"CPU", "GPU"}}, - {"Cast", V_2023_0, {"VPUX"}}, + {"Cast", V_2023_0, {"NPU"}}, + {"CastLike", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Ceil", V_2020_4, {"GPU"}}, {"Ceil", V_2021_4, {"CPU"}}, + {"Ceil", V_2023_1, {"NPU"}}, {"Celu", V_2022_1, {"CPU", "GPU"}}, {"Clip", V_2020_4, {"CPU", "GPU"}}, - {"Clip", V_2023_0, {"VPUX"}}, + {"Clip", V_2023_0, {"NPU"}}, + {"Compress", V_2023_1, {"CPU", "GPU"}}, {"Concat", V_2020_4, {"CPU", "GPU"}}, - {"Concat", V_2023_0, {"VPUX"}}, + {"Concat", V_2023_0, {"NPU"}}, {"Constant", V_2020_4, {"CPU", "GPU"}}, - {"Constant", V_2023_0, {"VPUX"}}, + {"Constant", V_2023_0, {"NPU"}}, {"ConstantOfShape", V_2020_4, {"CPU", "GPU"}}, - {"ConstantOfShape", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op in the plugin. + {"ConstantOfShape", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op in the plugin. {"Conv", V_2020_4, {"CPU", "GPU"}}, - {"Conv", V_2023_0, {"VPUX"}}, + {"Conv", V_2023_0, {"NPU"}}, {"ConvInteger", V_2022_1, {"CPU", "GPU"}}, + {"ConvInteger", V_2023_1, {"NPU"}}, {"ConvTranspose", V_2020_4, {"CPU", "GPU"}}, + {"ConvTranspose", V_2023_1, {"NPU"}}, {"Cos", V_2020_4, {"CPU"}}, {"Cos", V_2022_1, {"GPU"}}, - {"Cos", V_2023_0, {"VPUX"}}, + {"Cos", V_2023_0, {"NPU"}}, {"Cosh", V_2020_4, {"CPU"}}, {"Cosh", V_2022_1, {"GPU"}}, + {"Cosh", V_2023_1, {"NPU"}}, {"CumSum", V_2022_1, {"CPU", "GPU"}}, - {"CumSum", V_2023_0, {"VPUX"}}, + {"CumSum", V_2023_0, {"NPU"}}, {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, - {"DepthToSpace", V_2023_0, {"VPUX"}}, + {"DepthToSpace", V_2023_0, {"NPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"DequantizeLinear", V_2023_0, {"VPUX"}}, + {"DequantizeLinear", V_2023_0, {"NPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, - {"Div", V_2023_0, {"VPUX"}}, + {"Div", V_2023_0, {"NPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, - {"Dropout", V_2023_0, {"VPUX"}}, + {"Dropout", V_2023_0, {"NPU"}}, {"Elu", V_2020_4, {"CPU", "GPU"}}, - {"Elu", V_2023_0, {"VPUX"}}, + {"Elu", V_2023_0, {"NPU"}}, // {"Einsum", V_2023_0, {"CPU", "GPU"}}, {"Equal", V_2020_4, {"CPU", "GPU"}}, - {"Equal", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Equal", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Erf", V_2020_4, {"CPU", "GPU"}}, - {"Erf", V_2023_0, {"VPUX"}}, + {"Erf", V_2023_0, {"NPU"}}, {"Exp", V_2020_4, {"CPU", "GPU"}}, - {"Exp", V_2023_0, {"VPUX"}}, + {"Exp", V_2023_0, {"NPU"}}, {"Expand", V_2022_1, {"CPU", "GPU"}}, - {"Expand", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op and multiply op in the plugin. + {"Expand", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op and multiply op in the plugin. {"EyeLike", V_2022_1, {"CPU"}}, - {"EyeLike", V_2023_0, {"VPUX"}}, // NoOP + {"EyeLike", V_2023_0, {"NPU"}}, // NoOP {"Flatten", V_2020_4, {"CPU", "GPU"}}, - {"Flatten", V_2023_0, {"VPUX"}}, + {"Flatten", V_2023_0, {"NPU"}}, {"Floor", V_2020_4, {"CPU", "GPU"}}, + {"Floor", V_2023_1, {"NPU"}}, {"Gather", V_2020_4, {"CPU", "GPU"}}, - {"Gather", V_2023_0, {"VPUX"}}, + {"Gather", V_2023_0, {"NPU"}}, {"GatherElements", V_2022_2, {"CPU", "GPU"}}, + {"GatherElements", V_2023_1, {"NPU"}}, {"GatherND", V_2021_4, {"CPU", "GPU"}}, + {"GatherND", V_2023_1, {"NPU"}}, {"Gemm", V_2020_4, {"CPU", "GPU"}}, - {"Gemm", V_2023_0, {"VPUX"}}, + {"Gemm", V_2023_0, {"NPU"}}, {"GlobalAveragePool", V_2020_4, {"CPU", "GPU"}}, - {"GlobalAveragePool", V_2023_0, {"VPUX"}}, + {"GlobalAveragePool", V_2023_0, {"NPU"}}, {"GlobalLpPool", V_2020_4, {"CPU", "GPU"}}, + {"GlobalLpPool", V_2023_1, {"NPU"}}, {"GlobalMaxPool", V_2022_1, {"CPU", "GPU"}}, + {"GlobalMaxPool", V_2023_1, {"NPU"}}, {"Greater", V_2020_4, {"CPU", "GPU"}}, - {"Greater", V_2023_0, {"VPUX"}}, + {"Greater", V_2023_0, {"NPU"}}, {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"GreaterOrEqual", V_2023_0, {"VPUX"}}, + {"GreaterOrEqual", V_2023_0, {"NPU"}}, {"GridSample", V_2022_3, {"CPU"}}, {"GridSample", V_2023_0, {"GPU"}}, + {"GridSample", V_2023_1, {"NPU"}}, + {"HardMax", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Identity", V_2020_4, {"CPU", "GPU"}}, - {"Identity", V_2023_0, {"VPUX"}}, // NoOP + {"Identity", V_2023_0, {"NPU"}}, // NoOP {"If", V_2022_3, {"CPU", "GPU"}}, + {"If", V_2023_1, {"NPU"}}, {"ImageScaler", V_2022_1, {"CPU", "GPU"}}, - {"ImageScaler", V_2023_0, {"VPUX"}}, + {"ImageScaler", V_2023_0, {"NPU"}}, {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}}, - {"InstanceNormalization", V_2023_0, {"VPUX"}}, + {"InstanceNormalization", V_2023_0, {"NPU"}}, {"HardSigmoid", V_2020_4, {"CPU", "GPU"}}, + {"HardSigmoid", V_2023_1, {"NPU"}}, {"HardMax", V_2022_1, {"CPU", "GPU"}}, {"LeakyRelu", V_2020_4, {"CPU", "GPU"}}, - {"LeakyRelu", V_2023_0, {"VPUX"}}, + {"LeakyRelu", V_2023_0, {"NPU"}}, {"Less", V_2020_4, {"CPU", "GPU"}}, - {"Less", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Less", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"LessOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"LessOrEqual", V_2023_0, {"VPUX"}}, + {"LessOrEqual", V_2023_0, {"NPU"}}, {"Log", V_2020_4, {"CPU", "GPU"}}, - {"Log", V_2023_0, {"VPUX"}}, + {"Log", V_2023_0, {"NPU"}}, {"LogSoftMax", V_2022_1, {"CPU", "GPU"}}, {"Loop", V_2021_4, {"CPU", "GPU"}}, + {"LpNormalization", V_2023_1, {"CPU", "GPU", "NPU"}}, + {"LpPool", V_2023_1, {"CPU", "GPU", "NPU"}}, {"LRN", V_2020_4, {"CPU", "GPU"}}, - {"LRN", V_2023_0, {"VPUX"}}, + {"LRN", V_2023_0, {"NPU"}}, {"LSTM", V_2020_4, {"CPU", "GPU"}}, + {"LSTM", V_2023_1, {"NPU"}}, {"MatMul", V_2020_4, {"CPU", "GPU"}}, - {"MatMul", V_2023_0, {"VPUX"}}, + {"MatMul", V_2023_0, {"NPU"}}, {"MatMulInteger", V_2022_1, {"CPU"}}, + {"MatMulInteger", V_2023_1, {"NPU"}}, {"Max", V_2020_4, {"CPU", "GPU"}}, - {"Max", V_2023_0, {"VPUX"}}, + {"Max", V_2023_0, {"NPU"}}, {"MaxPool", V_2020_4, {"CPU", "GPU"}}, - {"MaxPool", V_2023_0, {"VPUX"}}, + {"MaxPool", V_2023_0, {"NPU"}}, {"Mean", V_2020_4, {"CPU", "GPU"}}, - {"Mean", V_2023_0, {"VPUX"}}, + {"Mean", V_2023_0, {"NPU"}}, {"MeanVarianceNormalization", V_2022_1, {"CPU", "GPU"}}, + {"MeanVarianceNormalization", V_2023_1, {"NPU"}}, {"Min", V_2020_4, {"CPU", "GPU"}}, - {"Min", V_2023_0, {"VPUX"}}, + {"Min", V_2023_0, {"NPU"}}, {"Mod", V_2022_1, {"CPU", "GPU"}}, {"Mul", V_2020_4, {"CPU", "GPU"}}, - {"Mul", V_2023_0, {"VPUX"}}, + {"Mul", V_2023_0, {"NPU"}}, {"Neg", V_2020_4, {"CPU", "GPU"}}, - {"Neg", V_2023_0, {"VPUX"}}, + {"Neg", V_2023_0, {"NPU"}}, {"NonMaxSuppression", V_2021_1, {"CPU", "GPU"}}, + {"NonMaxSuppression", V_2023_1, {"NPU"}}, {"NonZero", V_2021_1, {"CPU"}}, {"NonZero", V_2023_0, {"GPU"}}, {"Not", V_2021_1, {"CPU", "GPU"}}, {"Not", V_2020_4, {"CPU", "GPU"}}, + {"Not", V_2023_1, {"NPU"}}, {"OneHot", V_2020_4, {"CPU", "GPU"}}, + {"OneHot", V_2023_1, {"NPU"}}, {"Or", V_2022_1, {"CPU", "GPU"}}, + {"Or", V_2023_1, {"NPU"}}, {"Pad", V_2020_4, {"CPU", "GPU"}}, - {"Pad", V_2023_0, {"VPUX"}}, + {"Pad", V_2023_0, {"NPU"}}, {"Pow", V_2020_4, {"CPU", "GPU"}}, - {"Pow", V_2023_0, {"VPUX"}}, + {"Pow", V_2023_0, {"NPU"}}, {"PRelu", V_2020_4, {"CPU", "GPU"}}, - {"PRelu", V_2023_0, {"VPUX"}}, + {"PRelu", V_2023_0, {"NPU"}}, {"QLinearMatMul", V_2022_3, {"CPU"}}, + // {"QLinearMatMul", V_2023_1, {"NPU"}}, {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"QuantizeLinear", V_2023_0, {"VPUX"}}, + {"QuantizeLinear", V_2023_0, {"NPU"}}, + {"RNN", V_2023_1, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_1, {"NPU"}}, {"RandomNormal", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormal", V_2023_1, {"NPU"}}, {"Range", V_2022_1, {"CPU", "GPU"}}, - {"Range", V_2023_0, {"VPUX"}}, + {"Range", V_2023_0, {"NPU"}}, {"Reciprocal", V_2020_4, {"CPU", "GPU"}}, - {"Reciprocal", V_2023_0, {"VPUX"}}, + {"Reciprocal", V_2023_0, {"NPU"}}, {"ReduceL1", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL1", V_2023_1, {"NPU"}}, {"ReduceL2", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL2", V_2023_1, {"NPU"}}, {"ReduceLogSum", V_2020_4, {"CPU"}}, {"ReduceLogSum", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSum", V_2023_1, {"NPU"}}, {"ReduceLogSumExp", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSumExp", V_2023_1, {"NPU"}}, {"ReduceMax", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMax", V_2023_1, {"NPU"}}, {"ReduceMean", V_2020_4, {"CPU", "GPU"}}, - {"ReduceMean", V_2023_0, {"VPUX"}}, + {"ReduceMean", V_2023_0, {"NPU"}}, {"ReduceMin", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMin", V_2023_1, {"NPU"}}, {"ReduceProd", V_2020_4, {"CPU"}}, {"ReduceProd", V_2022_1, {"GPU"}}, + {"ReduceProd", V_2023_1, {"NPU"}}, {"ReduceSum", V_2020_4, {"CPU", "GPU"}}, + // {"ReduceSum", V_2023_1, {"NPU"}}, {"ReduceSumSquare", V_2020_4, {"CPU"}}, {"ReduceSumSquare", V_2022_1, {"CPU", "GPU"}}, + {"ReduceSumSquare", V_2023_1, {"NPU"}}, {"Relu", V_2020_4, {"CPU", "GPU"}}, - {"Relu", V_2023_0, {"VPUX"}}, + {"Relu", V_2023_0, {"NPU"}}, {"Resize", V_2020_4, {"CPU"}}, {"Resize", V_2022_1, {"GPU"}}, + {"Resize", V_2023_1, {"NPU"}}, {"Reshape", V_2020_4, {"CPU", "GPU"}}, - {"Reshape", V_2023_0, {"VPUX"}}, + {"Reshape", V_2023_0, {"NPU"}}, {"ReverseSequence", V_2022_1, {"CPU", "GPU"}}, {"RoiAlign", V_2021_1, {"CPU", "GPU"}}, + {"RoiAlign", V_2023_1, {"NPU"}}, {"Round", V_2021_4, {"CPU", "GPU"}}, + {"Round", V_2023_1, {"NPU"}}, {"Scatter", V_2022_1, {"CPU", "GPU"}}, + {"Scatter", V_2023_1, {"NPU"}}, {"ScatterElements", V_2022_1, {"CPU", "GPU"}}, + {"ScatterElements", V_2023_1, {"NPU"}}, {"ScatterND", V_2022_1, {"CPU", "GPU"}}, + {"ScatterND", V_2023_1, {"NPU"}}, {"Selu", V_2020_4, {"CPU", "GPU"}}, + {"Selu", V_2023_1, {"NPU"}}, {"Shape", V_2020_4, {"CPU", "GPU"}}, - {"Shape", V_2023_0, {"VPUX"}}, + {"Shape", V_2023_0, {"NPU"}}, {"Shrink", V_2022_1, {"CPU", "GPU"}}, - {"Shrink", V_2023_0, {"VPUX"}}, + {"Shrink", V_2023_0, {"NPU"}}, {"Sigmoid", V_2020_4, {"CPU", "GPU"}}, - {"Sigmoid", V_2023_0, {"VPUX"}}, + {"Sigmoid", V_2023_0, {"NPU"}}, {"Sign", V_2020_4, {"CPU"}}, {"Sign", V_2022_1, {"GPU"}}, - {"Sign", V_2023_0, {"VPUX"}}, + {"Sign", V_2023_0, {"NPU"}}, {"Sin", V_2022_1, {"CPU", "GPU"}}, - {"Sin", V_2023_0, {"VPUX"}}, + {"Sin", V_2023_0, {"NPU"}}, {"Sinh", V_2020_4, {"CPU"}}, + {"Sinh", V_2023_1, {"NPU"}}, {"Size", V_2022_1, {"CPU", "GPU"}}, + {"Size", V_2023_1, {"NPU"}}, {"Slice", V_2020_4, {"CPU", "GPU"}}, - {"Slice", V_2023_0, {"VPUX"}}, + {"Slice", V_2023_0, {"NPU"}}, {"Softmax", V_2020_4, {"CPU", "GPU"}}, - {"Softmax", V_2023_0, {"VPUX"}}, + {"Softmax", V_2023_0, {"NPU"}}, {"Softplus", V_2022_1, {"CPU", "GPU"}}, - {"Softplus", V_2023_0, {"VPUX"}}, + {"Softplus", V_2023_0, {"NPU"}}, {"Softsign", V_2022_1, {"CPU", "GPU"}}, {"SpaceToDepth", V_2020_4, {"CPU", "GPU"}}, - {"SpaceToDepth", V_2023_0, {"VPUX"}}, + {"SpaceToDepth", V_2023_0, {"NPU"}}, {"Split", V_2020_4, {"CPU", "GPU"}}, - {"Split", V_2023_0, {"VPUX"}}, + {"Split", V_2023_0, {"NPU"}}, {"Sqrt", V_2020_4, {"CPU", "GPU"}}, - {"Sqrt", V_2023_0, {"VPUX"}}, + {"Sqrt", V_2023_0, {"NPU"}}, {"Squeeze", V_2020_4, {"CPU", "GPU"}}, - {"Squeeze", V_2023_0, {"VPUX"}}, + {"Squeeze", V_2023_0, {"NPU"}}, {"Softsign", V_2020_4, {"CPU"}}, {"Sub", V_2020_4, {"CPU", "GPU"}}, - {"Sub", V_2023_0, {"VPUX"}}, + {"Sub", V_2023_0, {"NPU"}}, {"Sum", V_2020_4, {"CPU", "GPU"}}, - {"Sum", V_2023_0, {"VPUX"}}, + {"Sum", V_2023_0, {"NPU"}}, {"Tan", V_2020_4, {"CPU", "GPU"}}, + {"Tan", V_2023_1, {"NPU"}}, {"Tanh", V_2020_4, {"CPU", "GPU"}}, - {"Tanh", V_2023_0, {"VPUX"}}, + {"Tanh", V_2023_0, {"NPU"}}, {"ThresholdedRelu", V_2022_1, {"CPU", "GPU"}}, - {"ThresholdedRelu", V_2023_0, {"VPUX"}}, + {"ThresholdedRelu", V_2023_0, {"NPU"}}, {"Tile", V_2021_3, {"CPU", "GPU"}}, - {"Tile", V_2023_0, {"VPUX"}}, + {"Tile", V_2023_0, {"NPU"}}, {"Transpose", V_2020_4, {"CPU", "GPU"}}, - {"Transpose", V_2023_0, {"VPUX"}}, + {"Transpose", V_2023_0, {"NPU"}}, {"Trilu", V_2023_0, {"CPU", "GPU"}}, + {"Trilu", V_2023_1, {"NPU"}}, {"TopK", V_2020_4, {"CPU", "GPU"}}, - {"TopK", V_2023_0, {"VPUX"}}, + {"TopK", V_2023_0, {"NPU"}}, + {"Upsample", V_2020_4, {"CPU", "GPU"}}, {"Unsqueeze", V_2020_4, {"CPU", "GPU"}}, - {"Unsqueeze", V_2023_0, {"VPUX"}}, - {"Upsample", V_2021_1, {"CPU"}}, - {"Upsample", V_2021_4, {"GPU"}}, - {"Upsample", V_2023_0, {"VPUX"}}, + {"Unsqueeze", V_2023_0, {"NPU"}}, {"Where", V_2022_1, {"CPU", "GPU"}}, - {"Where", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Where", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Xor", V_2022_1, {"CPU", "GPU"}}, + {"Xor", V_2023_1, {"NPU"}}, }; void DataOps::populate_types_supported() { - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_initializer_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_initializer_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_vpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_npu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_cpu_.insert(std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_cpu_.insert( + std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_gpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_gpu_.insert(std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_gpu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_gpu_.insert( + std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); } void DataOps::populate_op_mode_supported() { @@ -349,10 +439,10 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}}); - no_dimension_supported_.push_back({"Greater", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Greater", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); - no_dimension_supported_.push_back({"Max", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Max", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"QuantizeLinear", V_2021_4, {"All"}}); @@ -382,11 +472,14 @@ void DataOps::populate_op_mode_supported() { { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { - // Abs is not supproted with INT8 or INT32 as input data type on GPU - if (device_id_.find("GPU") != std::string::npos) { + // Abs is not supproted with INT8 or INT32 as input data type on GPU and NPU + if ((device_id_.find("GPU") != std::string::npos) || + (device_id_.find("NPU") != std::string::npos)) { for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || - node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || + node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -399,11 +492,14 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // tensor type does not support select last index auto& attributes = node->GetAttributes(); - auto last_index_arg = attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() : 0; + auto last_index_arg = + attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() + : 0; if (last_index_arg != 0) return true; // tensor type supports float as input for argmax and argmin - if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) + if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) return true; return false; }}; @@ -415,7 +511,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { // int64 data type is not supported on GPU - const bool data_is_int64 = node->InputDefs()[0]->Type()->find("int64") != std::string::npos; + const bool data_is_int64 = + node->InputDefs()[0]->Type()->find("int64") != std::string::npos; return data_is_int64; } return false; @@ -506,9 +603,12 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto x_data_type = node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); auto y_data_type = node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - // currently both inputs with int32 are not supported and also both input datatypes should be same - const bool A_is_int32 = node->InputDefs()[0]->Type()->find("int32") != std::string::npos; - const bool B_is_int32 = node->InputDefs()[1]->Type()->find("int32") != std::string::npos; + // currently both inputs with int32 are not supported + // and also both input datatypes should be same + const bool A_is_int32 = + node->InputDefs()[0]->Type()->find("int32") != std::string::npos; + const bool B_is_int32 = + node->InputDefs()[1]->Type()->find("int32") != std::string::npos; if ((A_is_int32 && B_is_int32) || (x_data_type != y_data_type)) return true; } @@ -589,11 +689,13 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto slope = node->InputDefs()[1]; // PRelu slope has to be an initializer or needs to come from a constant node - if (initializers.count(slope->Name())) + if (initializers.count(slope->Name())) { return false; - else { - for (auto input_node = node->InputNodesBegin(); input_node != node->InputNodesEnd(); ++input_node) { - if (GetInputCount(this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) + } else { + for (auto input_node = node->InputNodesBegin(); + input_node != node->InputNodesEnd(); ++input_node) { + if (GetInputCount( + this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) return false; } } @@ -603,12 +705,12 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"PRelu", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); // Reshape op with empty dim is Rejected for Myriad - //[TODO] Is this condition required anymore with Myriad removed? + // [TODO] Is this condition required anymore with Myriad removed? if (shape != nullptr) { for (const auto& dim : input_arg->Shape()->dim()) { if (utils::HasDimValue(dim) && dim.dim_value() == 0) @@ -638,7 +740,8 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { // INT32 dataype is not supported as input for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -650,9 +753,11 @@ void DataOps::populate_op_mode_supported() { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { - auto output_data_type = node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + auto output_data_type = + node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); // If the output of ScatterND op is BOOL, it is rejected for GPU. - if (output_data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) + if (output_data_type == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) return true; } return false; @@ -666,7 +771,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // If the Input of Shrink op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) return true; } return false; @@ -714,10 +820,11 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Squeeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze - // If axes is an input, then we cannot produce a static graph. Conversion fails in convert_function_to_cnn_network. + // If axes is an input, then we cannot produce a static graph. + // Conversion fails in convert_function_to_cnn_network. for (size_t i = 0; i < node->InputDefs().size(); i++) { if (node->InputDefs()[i]->Name() == "axes") { return true; @@ -728,14 +835,15 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); if (upsample_attr.count("scales") > 0) { auto& upsample_arg = upsample_attr.at("scales"); auto float_size = upsample_arg.floats_size(); - if (float_size > 2 && (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { + if (float_size > 2 && + (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { return true; } } @@ -750,9 +858,12 @@ void DataOps::populate_op_mode_supported() { } } // x_arg supports only float, int8 and float16 type - if ((x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { + if ((x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { return false; } else { return true; @@ -849,9 +960,9 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { } else { auto dtype = type_proto->tensor_type().elem_type(); - if (device_id_.find("VPUX") != std::string::npos || device_id_.find("HETERO") != std::string::npos || + if (device_id_.find("NPU") != std::string::npos || device_id_.find("HETERO") != std::string::npos || device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { - for (auto const& var : supported_types_vpu_) { + for (auto const& var : supported_types_npu_) { if ((var.first <= version_id_) && (var.second == dtype)) { return true; @@ -1079,7 +1190,9 @@ bool DataOps::node_is_supported(const std::mapsecond.find(optype) == opset->second.end() && op_fun == ops_supported_as_function.end()) { #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "The operator is not available in OpenVINO ngraph operators list nor the operator is a special ONNX function" << std::endl; + std::cout << "The operator is not available in OpenVINO ngraph operators list" + << "nor the operator is a special ONNX function" + << std::endl; } #endif return false; @@ -1095,10 +1208,12 @@ std::vector DataOps::GetUnsupportedNodeIndices(std::unordered_setForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, bool is_input) { - if(is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { + graph_viewer_.GetNode(node_idx)->ForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, + bool is_input) { + if (is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { ng_required_initializers.insert(node_arg.Name()); - } }, true); + } }, + true); } else { unsupported_nodes_idx.push_back(node_idx); } @@ -1110,7 +1225,8 @@ bool DataOps::IsOpSupportedOnlyInModel(std::string name) { return ops_supported_only_in_model.find(name) != ops_supported_only_in_model.end(); } -bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node) { +bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, + const Node* node) { if (node->OpType() == "Reshape") { const auto& shape_arg = node->InputDefs()[1]; if (ng_required_initializers.find(shape_arg->Name()) == ng_required_initializers.end()) { @@ -1119,15 +1235,20 @@ bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& } else if (node->OpType() == "Expand") { // nGraph only supports constant shape input values const auto& output = node->OutputDefs()[0]; - if (output->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) + if (output->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) return true; } else if (node->OpType() == "RoiAlign") { using onnx_dtype = ONNX_NAMESPACE::TensorProto_DataType; - onnx_dtype input_0_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_1_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_2_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype output_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_0_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_1_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_2_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype output_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); if ((input_0_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || (input_1_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index cc968d02ea644..a5aa3f825602c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -3,6 +3,11 @@ #pragma once #include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -47,7 +52,7 @@ class DataOps { std::multimap op_list_; std::vector subgraph_supported_; std::vector no_dimension_supported_; - std::set supported_types_vpu_; + std::set supported_types_npu_; std::set supported_types_cpu_; std::set supported_types_gpu_; std::set supported_types_initializer_; @@ -64,14 +69,16 @@ class DataOps { const NodeIndex node_idx); public: - DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { + DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) + : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { populate_op_mode_supported(); populate_types_supported(); } virtual std::vector GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers); virtual bool IsOpSupportedOnlyInModel(std::string name); - virtual bool SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node); + virtual bool SpecialConditionForClusterSizeOne( + std::unordered_set& ng_required_initializers, const Node* node); virtual bool DoNotOmitSubGraph(const std::string& name); virtual bool InsertNode(const std::string& name); VersionNum GetVersion() const { return version_id_; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index be509b6743621..74369d39b9a24 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License #include "core/providers/shared_library/provider_api.h" +#include "utils.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245 5208) @@ -113,7 +114,8 @@ std::map> GetNgSupportedOps(const int onnx_op * supported_cluster + (UNsupported_node + rest_of_the_graph). This functions returns vector of all supported_clusters by nGraph */ std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes) { +GetPartitionedClusters(const std::vector& topological_order, + const std::vector& unsupported_nodes) { std::vector> ng_clusters; auto prev = topological_order.begin(); @@ -140,7 +142,10 @@ GetPartitionedClusters(const std::vector& topological_order, const st return ng_clusters; } -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster) { +void IdentifyConnectedNodes(const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster) { if (std::find(cluster.begin(), cluster.end(), curr_node_index) == cluster.end()) return; @@ -205,7 +210,8 @@ void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const auto& ext_node = graph_viewer.GetNode((*it).Index()); if (std::find(cluster.begin(), cluster.end(), ext_node->Index()) == cluster.end()) { - // Node is external to this_cluster. Search through its inputs to find the output that is generated by this_cluster. + // Node is external to this_cluster. Search through its inputs to + // find the output that is generated by this_cluster. std::set ext_node_inputs; ext_node->ForEachDef( [&ext_node_inputs](const NodeArg& arg, bool is_input) { diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 70f6954ea991c..c256cde97956e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -1,5 +1,15 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -18,9 +28,14 @@ int GetOnnxOpSet(const GraphViewer& graph_viewer); std::map> GetNgSupportedOps(const int onnx_opset); std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes); - -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster); +GetPartitionedClusters( + const std::vector& topological_order, const std::vector& unsupported_nodes); + +void IdentifyConnectedNodes( + const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster); std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 9e59883478227..df4dd55417755 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1432,7 +1432,7 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_vpu_fast_compile"] = legacy_ov_options->enable_vpu_fast_compile; + ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7faca3b4681b8..2027b592326df 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -813,10 +813,10 @@ std::unique_ptr CreateExecutionProviderInstance( if (option.first == "device_type") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "enable_vpu_fast_compile") { + } else if (option.first == "enable_npu_fast_compile") { if (!(option.second == "True" || option.second == "true" || option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_vpu_fast_compile: ", option.second); + ORT_THROW("Invalid value passed for enable_npu_fast_compile: ", option.second); } OV_provider_options_map[option.first] = option.second; } else if (option.first == "enable_opencl_throttling") { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 5bb6bcc38b6fe..a5bcbce89bac6 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -60,11 +60,11 @@ struct OrtStatus { #elif OPENVINO_CONFIG_GPU_FP16 #define BACKEND_OPENVINO "-OPENVINO_GPU_FP16" -#elif OPENVINO_CONFIG_VPUX_FP16 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_FP16" +#elif OPENVINO_CONFIG_NPU_FP16 +#define BACKEND_OPENVINO "-OPENVINO_NPU_FP16" -#elif OPENVINO_CONFIG_VPUX_U8 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_U8" +#elif OPENVINO_CONFIG_NPU_U8 +#define BACKEND_OPENVINO "-OPENVINO_NPU_U8" #elif OPENVINO_CONFIG_MULTI #define BACKEND_OPENVINO "-OPENVINO_MULTI" diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b1a04a00e89b1..6d075fec997b5 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -60,7 +60,7 @@ namespace perftest { "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" - "\t [OpenVINO only] [enable_vpu_fast_compile]: Optionally enabled to speeds up the model's compilation on VPU device targets.\n" + "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" @@ -72,7 +72,7 @@ namespace perftest { "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_vpu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" "\t [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 41a1eafebbb50..b7a111783fc94 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -240,8 +240,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (key == "device_type") { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { ov_options[key] = value; } else if (value.find("HETERO:") == 0) { @@ -254,17 +253,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } else if (key == "device_id") { ov_options[key] = value; - } else if (key == "enable_vpu_fast_compile") { + } else if (key == "enable_npu_fast_compile") { if (value == "true" || value == "True" || value == "false" || value == "False") { ov_options[key] = value; } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_vpu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_npu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); } } else if (key == "enable_opencl_throttling") { if (value == "true" || value == "True" || @@ -299,7 +298,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ov_options[key] = value; } } else { - ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_vpu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); + ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); } } session_options.AppendExecutionProvider("OpenVINO", ov_options); diff --git a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc index e37206d6aebf2..b7cead66bd7fb 100644 --- a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc @@ -143,7 +143,7 @@ void L1NormalizationWithZeroNorm() { vector expected_output = {0.5f, 0.5f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L1NormalizationWithZeroNorm) { @@ -163,7 +163,7 @@ void L2NormalizationWithZeroNorm() { vector expected_output = {1.f, 0.f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L2NormalizationWithZeroNorm) { diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index d1a523b1eecf9..b9875b9553a55 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -762,7 +762,7 @@ TEST(RNNTest, RNN_invalid_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // the CUDA RNN version allows the invalid sequence lengths, so disable testing on CUDA and TensorRT - test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); }; // should batch batch_size to be valid @@ -860,7 +860,7 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(RNNTest, RNN_with_invalid_activation_load_failure) { diff --git a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc index c95ac1603a317..c3d91100605e9 100644 --- a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc @@ -66,7 +66,7 @@ TEST(CompressTest, Compress_3dims_has_extra_condition) { // has condition length = 3 > input_dim[axis] = 2 test.AddInput("condition", {3}, {0, 1, 1}); test.AddOutput("output", {2, 1, 3}, {4.0f, 5.0f, 6.0f, 10.0f, 11.0f, 12.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(CompressTest, Compress_3dims_has_extra_input) { diff --git a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc index 2120da604f94a..d2aa5dd428fec 100644 --- a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc @@ -99,7 +99,7 @@ TEST(TensorOpTest, Unsqueeze_scalar_2) { test.AddInput("input", {}, std::vector{1.0f}); test.AddInput("axes", {2}, std::vector{0, -1}, axes_is_initializer); test.AddOutput("output", {1, 1}, std::vector{1.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); }; run_test(false); run_test(true); diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index ecf4b001eec68..c48b07422d452 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -140,6 +140,9 @@ def create_backend_test(test_name=None): if backend.supports_device("OPENVINO_CPU_FP16"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16") + if backend.supports_device("OPENVINO_NPU_FP16"): + current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU_FP16") + if backend.supports_device("OPENVINO"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18") diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 44db7c0078cfc..c552ec3aea72d 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -521,6 +521,10 @@ "test_scan_sum_cpu", // Disabled due to output mismatch with tolerance. "test_scan9_sum_cpu" // Disabled due to output mismatch with tolerance. ], + "current_failing_tests_OPENVINO_NPU_FP16": [ + "^test_prelu_broadcast", + "test_loop11_cpu" + ], "current_failing_tests_OPENVINO_opset18": [ // pending opset 18 support, RUNTIME_EXCEPTION : Encountered unknown exception in Initialize() "^test_center_crop_pad_crop_axes_chw", diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 806e536cb4ddb..a992da8ff993e 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -66,15 +66,13 @@ def _str_to_bool(s): def _openvino_verify_device_type(device_read): - choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16", "VPUX_FP16", "VPUX_U8"] + choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16"] choices1 = [ "CPU_FP32_NO_PARTITION", "CPU_FP16_NO_PARTITION", "GPU_FP32_NO_PARTITION", "GPU_FP16_NO_PARTITION", - "VPUX_FP16_NO_PARTITION", - "VPUX_U8_NO_PARTITION", ] status_hetero = True res = False @@ -89,7 +87,7 @@ def _openvino_verify_device_type(device_read): if len(comma_separated_devices) < 2: print("At least two devices required in Hetero/Multi/Auto Mode") status_hetero = False - dev_options = ["CPU", "GPU", "VPUX"] + dev_options = ["CPU", "GPU"] for dev in comma_separated_devices: if dev not in dev_options: status_hetero = False @@ -100,7 +98,7 @@ def invalid_hetero_build(): print("specify the keyword HETERO or MULTI or AUTO followed by the devices ") print("in the order of priority you want to build\n") print("The different hardware devices that can be added in HETERO or MULTI or AUTO") - print("are ['CPU','GPU', 'VPUX'] \n") + print("are ['CPU','GPU'] \n") print("An example of how to specify the hetero build type. Ex: HETERO:GPU,CPU \n") print("An example of how to specify the MULTI build type. Ex: MULTI:GPU,CPU \n") print("An example of how to specify the AUTO build type. Ex: AUTO:GPU,CPU \n") @@ -1158,8 +1156,6 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_GPU_FP16=" + ("ON" if args.use_openvino == "GPU_FP16" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP32=" + ("ON" if args.use_openvino == "CPU_FP32" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16=" + ("ON" if args.use_openvino == "CPU_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16=" + ("ON" if args.use_openvino == "VPUX_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8=" + ("ON" if args.use_openvino == "VPUX_U8" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP32_NP=" + ("ON" if args.use_openvino == "GPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP16_NP=" @@ -1168,9 +1164,6 @@ def generate_build_tree( + ("ON" if args.use_openvino == "CPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16_NP=" + ("ON" if args.use_openvino == "CPU_FP16_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16_NP=" - + ("ON" if args.use_openvino == "VPUX_FP16_NP_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8_NP=" + ("ON" if args.use_openvino == "VPUX_U8_NP_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_HETERO=" + ("ON" if args.use_openvino.startswith("HETERO") else "OFF"), "-Donnxruntime_USE_OPENVINO_DEVICE=" + (args.use_openvino), "-Donnxruntime_USE_OPENVINO_MULTI=" + ("ON" if args.use_openvino.startswith("MULTI") else "OFF"), diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index cc27cdc293646..f7b68551b9c50 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -552,6 +552,7 @@ def generate_files(line_list, args): files_list.append( "" ) + else: files_list.append( "' - ) + dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") + tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") + for dll_element in os.listdir(dll_list_path): if dll_element.endswith("dll"): files_list.append( @@ -735,26 +720,7 @@ def generate_files(line_list, args): + args.target_architecture + '\\native" />' ) - # plugins.xml - files_list.append( - "' - ) - # usb-ma2x8x.mvcmd - # OpenVINO 2022.3 doesn't have usb-ma2x8x.mvcmd - if "2022.3" not in openvino_path: - files_list.append( - "' - ) + for tbb_element in os.listdir(tbb_list_path): if tbb_element.endswith("dll"): files_list.append( From 9e8ad398479d9c2dc0ca91a8df89e452d059f6ee Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 1 Nov 2023 08:49:33 -0700 Subject: [PATCH 20/21] Distributed Reduction (#18206) This PR implements distributed reduciton for llama 2. This version doesn't consider any cases requring re-sharding because we haven't seen any use cases. Intutive examples: - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[0]) -> [1,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[1]) -> [2,1,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [not supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[2]) -> [2,4,1]-tensor with spec=RRS[0] and device_mesh=[0,1] Algorithm: When the reduced axes are not sharded, each device can call reduction directly. The output sharding spec will be identical to input sharding spec. We currently throw when input and output sharding specs are different. Review guideline: - Check 97b8d2f for new op's schema and how new op is registered. - Read tests in 2450f93 to get faimilar with the behavior of these ops. - Check the implementation details in 753d9af. --- cmake/onnxruntime_providers_cuda.cmake | 1 + cmake/onnxruntime_rocm_hipify.cmake | 1 + .../cuda/collective/distributed_reduce.cc | 175 +++++++++ .../cuda/collective/distributed_reduce.h | 59 +++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 18 + .../core/graph/contrib_ops/collective_defs.cc | 123 +++++++ .../providers/cuda/reduction/reduction_ops.cc | 24 ++ .../python/onnxruntime_test_distributed.py | 345 ++++++++++++------ 8 files changed, 638 insertions(+), 108 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 043789c36c327..ce0c12804b08a 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -40,6 +40,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6ccf063c71290..9bc2bdd208a92 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -109,6 +109,7 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc new file mode 100644 index 0000000000000..967f30a304ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -0,0 +1,175 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reduce.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/reduction/reduction_ops.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedReduceBase::DistributedReduceBase( + const OpKernelInfo& info, + cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) { + keepdims_ = info.GetAttrOrDefault("keepdims", 1); + cudnn_reduce_op_ = cudnn_reduce_op; +}; + +template +Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const { + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& axes_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(axes_sharding_spec.HasNoShard(), + "It's not worthy to shard axes tensor. " + "If sharding axes is needed, please submit a feature request."); + + const Tensor* input_tensor = context->Input(0); + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor."); + auto axes_span = axes_tensor->DataAsSpan(); + + // Case 1: empty axes means treating this reduction as an identity. + if (axes_span.empty()) { + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData(), input_tensor->Data(), input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + // Case 2: this is a valid reduction. Let's prepare for it. + + bool sharding_on_reduced_axes = false; + for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) { + if (*axis_it == input_sharding_spec.GetPartitionAxis()) { + sharding_on_reduced_axes = true; + break; + } + } + + if (sharding_on_reduced_axes) { + // Case 2-1: sharding on reduced axes. + ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica."); + } else { + // Case 2-2: sharding on passing-through axes or no shard. + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + onnxruntime::cuda::PrepareReduceMetadata metadata; + ORT_RETURN_IF_ERROR( + onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)); + auto output_tensor = context->Output(0, metadata.squeezed_output_dims); + + // Fast reduction is not deterministic, so sometimes we want to turn it off. + const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); + return onnxruntime::cuda::ReduceComputeCore( + /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, + /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, + enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + } + return Status::OK(); +} + +template +DistributedReduceSum::DistributedReduceSum( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_ADD){}; + +template +DistributedReduceMean::DistributedReduceMean( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_AVG){}; + +template +DistributedReduceMax::DistributedReduceMax( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_MAX){}; + +// ReduceSum +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); + +// ReduceMean +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); + +// ReduceMax +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h new file mode 100644 index 0000000000000..2939852c75c60 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReduceBase : public DistributedKernel { + public: + explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + // ONNX attribute. If true, reduced axes are retained as dimensions with size one. + // Otherwise, drop reduced axes. + bool keepdims_; + cudnnReduceTensorOp_t cudnn_reduce_op_; +}; + +template +class DistributedReduceSum final : public DistributedReduceBase { + public: + explicit DistributedReduceSum(const OpKernelInfo& info); +}; + +template +class DistributedReduceMean final : public DistributedReduceBase { + public: + explicit DistributedReduceMean(const OpKernelInfo& info); +}; + +template +class DistributedReduceMax final : public DistributedReduceBase { + public: + explicit DistributedReduceMax(const OpKernelInfo& info); +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index d51915b85095f..8e157da6cb43f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -175,6 +175,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean); #endif template <> @@ -354,6 +363,15 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 070df487a264d..8b5b561c1ad87 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -273,6 +273,129 @@ void RegisterCollectiveOps() { OpSchema::NonDifferentiable) .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceSum) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMax) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMean) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index d46ed9c245a8e..bc78e577c5052 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -614,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index e0fb3979a9f55..6f691972181b5 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -7,7 +7,7 @@ import numpy as np import onnxscript from mpi4py import MPI -from onnxscript import FLOAT, INT64 +from onnxscript import FLOAT, FLOAT16, INT64 import onnxruntime as ort @@ -27,12 +27,23 @@ def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): return np.concatenate(selected_shards, axis=axis) -def translate_device_mesh_to_attrs(device_mesh: np.ndarray): +def translate_single_device_mesh(device_mesh: np.ndarray): device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" return device_mesh_shape, device_mesh_elements +def translate_all_device_meshes(device_meshes: np.ndarray): + assert all(len(mesh.shape) == 1 for mesh in device_meshes) + device_mesh_shapes = [] + device_mesh_elements = [] + for device_mesh in device_meshes: + device_mesh_shape, device_mesh_element = translate_single_device_mesh(device_mesh) + device_mesh_shapes.append(device_mesh_shape) + device_mesh_elements.append(device_mesh_element) + return device_mesh_shapes, device_mesh_elements + + def parse_sharding_spec(spec: str): axis_conditions = [] sharding_device_axes = [] @@ -90,29 +101,13 @@ def _check_distributed_reshape( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -134,11 +129,11 @@ def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = np.reshape(data_tensor, shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_reshape_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -176,9 +171,9 @@ def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): 3, ), target_shape=(6,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]",), ) @@ -191,9 +186,9 @@ def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): 4, ), target_shape=(8,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]",), ) @@ -210,9 +205,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): 2, 15, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -229,9 +224,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): 2, 20, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -248,9 +243,9 @@ def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): 2, 18, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) # Two axis fusion. @@ -268,9 +263,9 @@ def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): 2, 3, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -283,9 +278,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): 1, 16, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -298,9 +293,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -313,9 +308,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): 4, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -328,9 +323,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): 8, 2, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -343,9 +338,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): 16, 1, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -359,9 +354,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self) 1, 16, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -375,9 +370,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -390,9 +385,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): 4, 4, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -405,9 +400,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): 8, 2, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -420,9 +415,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self) 16, 1, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -444,9 +439,9 @@ def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -471,9 +466,9 @@ def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rr 64, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]R",), ) @@ -495,9 +490,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -519,9 +514,9 @@ def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self) 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRR",), ) @@ -546,9 +541,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_0101 7, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -573,9 +568,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_ 7, 7, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -600,9 +595,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101 7, 7, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -627,9 +622,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_6 7, 64, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -654,9 +649,9 @@ def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(s 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -678,9 +673,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -690,29 +685,16 @@ def _check_distributed_expand( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -734,11 +716,11 @@ def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = data_tensor * np.ones(shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_expand_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -780,9 +762,9 @@ def test_expand_sharded_on_expanded_axis(self): 8, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -799,9 +781,9 @@ def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): 8, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -818,9 +800,9 @@ def test_expand_replicated_on_expanded_axis(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -837,12 +819,12 @@ def test_expand_with_pass_through_sharding_spec(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=( "S[0]R", "R", ), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -863,13 +845,160 @@ def test_expand_in_tiny_llama(self): 256, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) +class TestDistributedReduce(unittest.TestCase): + def _check_distributed_reduce( + self, + keepdims: int, + dtype: np.dtype, + shape: Tuple[int, ...], + axes: Tuple[int, ...], + input_device_meshes: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshes: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) + + @onnxscript.script() + def distributed_reduce_sum_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceSum( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_max_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMax( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMean( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + + for onnx_func, np_func in zip( + [distributed_reduce_sum_instance, distributed_reduce_max_instance, distributed_reduce_mean_instance], + [np.sum, np.maximum.reduce, np.mean], + ): + data = np.random.randint(4, size=shape).astype(dtype) + expected = np_func(data, axis=axes, keepdims=bool(keepdims)) + + assert len(input_shard_specs) == 2 and len(input_device_meshes) == 2, "Reduce has two inputs." + assert "S" not in input_shard_specs[1], "Tensor `axes` should not be sharded." + assert len(output_shard_specs) == 1 and len(output_device_meshes) == 1, "Reduce has only one output." + + local_data = shard_tensor_per_spec(data, rank, input_shard_specs[0], input_device_meshes[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) + + if dtype == np.float32: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + elif dtype == np.int64: + onnx_model = onnx_func.to_model_proto( + input_types=[INT64[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[INT64[tuple(local_expected.shape)]], + ) + elif dtype == np.float16: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT16[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT16[tuple(local_expected.shape)]], + ) + else: + raise RuntimeError(f"Unsupported dtype: {dtype}") + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data, + "axes_tensor": np.array(axes, dtype=np.int64), + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reduce(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(0,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reduce_sharded(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(1,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2]. From a2e9ba72d5a5f61e1324ffc2a80d748d01be9120 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Wed, 1 Nov 2023 15:34:51 -0700 Subject: [PATCH 21/21] [JS/Web]Added FusedConv. (#17766) ### Description Added FusedConv and FusedConvTranspose ### Motivation and Context Improve performance --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + .../webgpu/ops/3rd-party/activation_util.ts | 4 +- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 5 +- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 4 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 6 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 37 +++--- js/web/test/data/ops/fused-conv.jsonc | 112 ++++++++++++++++++ onnxruntime/contrib_ops/js/fused_conv.cc | 20 ++++ .../contrib_ops/js/js_contrib_kernels.cc | 5 +- .../core/optimizer/conv_activation_fusion.cc | 31 ++++- .../core/optimizer/conv_add_act_fusion.cc | 7 +- .../core/optimizer/graph_transformer_utils.cc | 13 +- .../selector_action_transformer.cc | 20 ++-- .../selector_action_transformer.h | 17 ++- .../core/providers/js/operators/conv.cc | 2 + .../core/providers/js/operators/conv.h | 78 ++++++++---- .../providers/js/operators/conv_transpose.cc | 2 + .../providers/js/operators/conv_transpose.h | 55 ++++++--- .../test/optimizer/graph_transform_test.cc | 13 +- 21 files changed, 339 insertions(+), 98 deletions(-) create mode 100644 js/web/test/data/ops/fused-conv.jsonc create mode 100644 onnxruntime/contrib_ops/js/fused_conv.cc diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 44003021293b0..5b94a4a510934 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -40,6 +40,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherElements | ai.onnx(11-12,13+) | | | Gelu | com.microsoft(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 40309c1849bcc..a4d51e68b6a25 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -67,6 +67,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['Gelu', [unaryOps.gelu]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index 22b91d680a9b4..6481a6b21d723 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -41,12 +41,12 @@ export const activationFnSnippet = if (!activation) { return ''; } - // TODO: add implementations return ''; }; export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} - ${activation ? 'value = activation(value, coords);' : ''} + // TODO uncomment the following line when activation is supported above. + // ${activation ? 'value = activation(value, coords);' : ''} `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 01ddca520deed..fbb936a045b9c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -242,8 +242,9 @@ export const createConv2DMatMulProgramInfo = ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2], t)} + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, + attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1], + elementsSize[2], t)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 840360223c75a..a95d3830f34eb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -236,7 +236,9 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + conv2dTransposeCommonSnippet( + isChannelsLast, hasBias, attributes.activation.toLowerCase() as Activation, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 1032869412462..0a0f29db6a494 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo} from '../../types'; import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; -import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -440,7 +440,7 @@ export const createMatmulProgramInfo = const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, isVec4); // TODO: fine tune size const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; @@ -473,8 +473,8 @@ export const createMatmulProgramInfo = const dimBOuter: i32 = ${dimBOuter}; const dimInner: i32 = ${dimInner}; ${shaderHelper.declareVariables(...inputVariables, output)} - ${declareFunctions} ${activationFunction} + ${declareFunctions} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7abf022928ade..8bfa722dd0909 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActicationSnippet} from './fuse-utils'; +import {getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -22,7 +22,7 @@ export const createGroupedConvProgramInfo = const wShape = inputs[1].dims; const outputChannelsPerGroup = wShape[0] / attributes.group; - const {activationFunction, applyActivation} = getActicationSnippet(attributes); + const {activationFunction, applyActivation} = getActivationSnippet(attributes); const isChannelLast = attributes.format === 'NHWC'; const outputShape = calculateOutputShape( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 92105859a8c0e..956ef18eb5cfb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -10,24 +10,25 @@ export interface InternalActivationAttributes { readonly activationCacheKey: string; } -export const getActicationSnippet = - (attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; - case 'Sigmoid': - return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; - case 'Clip': - return { - activationFunction: - `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, isVec4 = false): { + activationFunction: string; applyActivation: string; +} => { + switch (attributes.activation) { + case 'Relu': + return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; + case 'Sigmoid': + return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; + case 'Clip': + return { + activationFunction: `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, + applyActivation: isVec4 ? 'value = clamp(value, vec4(clip_min_), vec4(clip_max_));' : + 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. + default: + return {activationFunction: '', applyActivation: ''}; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc new file mode 100644 index 0000000000000..812e9d7c2def0 --- /dev/null +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "conv without bias addition A", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "T[1]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv without bias addition A", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/contrib_ops/js/fused_conv.cc b/onnxruntime/contrib_ops/js/fused_conv.cc new file mode 100644 index 0000000000000..76402f0681976 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fused_conv.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/operators/conv.h" +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + onnxruntime::js::Conv); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 4641b006a7785..24d327576ecd9 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -11,6 +11,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -23,7 +24,9 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index c090ab2a6cc9b..d27603e4ab3a1 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -4,7 +4,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include - +#include #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas.h" @@ -174,9 +174,29 @@ using NTO = NodesToOptimize; class FuseConvActivationAction : public ReplaceWithNew { private: - std::string OpType(const RuntimeState&) const override { return "FusedConv"; } + std::string OpType(const RuntimeState& runtime_state) const override { + const auto& domain = runtime_state.selected_nodes.Target().Domain(); + const auto& op_type = runtime_state.selected_nodes.Target().OpType(); + if (domain == kOnnxDomain) { + if (op_type == "Conv") { + return "FusedConv"; + } + } else if (domain == kMSDomain) { + if (op_type == "NhwcConv") { + return "NhwcFusedConv"; + } + } else if (domain == kMSInternalNHWCDomain) { + if (op_type == "Conv") { + return "Conv"; + } + } + ORT_THROW("Unsupported operator: ", op_type, " and domain: ", domain); + } - std::string Domain(const RuntimeState&) const override { return kMSDomain; } + std::string Domain(const RuntimeState& runtime_state) const override { + auto domain = runtime_state.selected_nodes.Target().Domain(); + return domain == kOnnxDomain ? kMSDomain : domain; + } NodeAttributes ExtraAttributes(const RuntimeState& state) const override { NodeAttributes extra_fused_conv_attributes; @@ -260,8 +280,11 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { const auto name = "ConvAct"; auto action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + const std::string msInternalNHWCDomainConv = SelectorActionRegistry::OpVersionsMapKey("Conv", kMSInternalNHWCDomain); + const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, + + registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {11}}, {msDomainConv, {1}}}, std::move(selector), std::move(action)); #else registry.RegisterAction(name, std::move(action)); diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 7c8bfeaec5f0f..6f90eaf07ef4d 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -287,12 +287,9 @@ class FuseConvAddActivationAction : public ReplaceWithNew { void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) { auto action = std::make_unique(); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}}, + std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain); + registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}}, std::move(selector), std::move(action)); - auto action_nhwc = std::make_unique(); - auto selector_nhwc = std::make_unique(); - registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}}, - std::move(selector_nhwc), std::move(action_nhwc)); } SelectorActionRegistry CreateSelectorActionRegistry() { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 5a441b1d1701e..86b126f2c7c31 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -270,11 +270,12 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = @@ -296,7 +297,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index e182b6c695d2f..546d52b6f1682 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -3,9 +3,10 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" -#include #include +#include #include +#include #include #include "core/graph/op_identifier_utils.h" @@ -56,9 +57,9 @@ const SelectorActionRegistry::Entry* SelectorActionRegistry::LookUp(const std::s } #if !defined(ORT_MINIMAL_BUILD) -auto SelectorActionRegistry::LookUpByOpType(const std::string& op_type) const +auto SelectorActionRegistry::LookUpByOpTypeAndDomain(const std::string& op_type, const std::string& domain) const -> std::vector> { - const auto [range_begin, range_end] = op_type_to_entry_.equal_range(op_type); + const auto [range_begin, range_end] = op_type_to_entry_.equal_range(OpVersionsMapKey(op_type, domain)); std::vector> result{}; result.reserve(std::distance(range_begin, range_end)); std::transform(range_begin, range_end, std::back_inserter(result), @@ -93,20 +94,15 @@ static Status MatchAndProcess( Status status = Status::OK(); do { - // TODO: for now this just needs to support ONNX and Micrsoft Domain ops. - // If we ever had a transformer that was going to target non-ONNX ops, - // we'd need to rework a few things to include the op domain in the matches - if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) { - break; - } - std::optional node_selection_opt{}; const SelectorActionRegistry::Entry* selector_action_entry_ptr = nullptr; - const auto selector_action_entries = selector_action_registry.LookUpByOpType(node.OpType()); + const auto selector_action_entries = + selector_action_registry.LookUpByOpTypeAndDomain(node.OpType(), node.Domain()); + std::string key = SelectorActionRegistry::OpVersionsMapKey(node.OpType(), node.Domain()); for (const auto& entry : selector_action_entries) { // check the supported versions if specified - const auto& versions = entry->ops_and_versions.find(node.OpType())->second; + const auto& versions = entry->ops_and_versions.find(key)->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node.SinceVersion()) == versions.cend()) { continue; diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h index 7eb162cc693f1..5caa949ebbe93 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h @@ -38,8 +38,20 @@ struct NodeSelector { // class to manage a set of selector and associated actions class SelectorActionRegistry { public: + // The key is a string representing the op, optionally specifying the domain using ':' as the + // separator with domain as the first part and operator as the second part, ":" or "". + // For ops in kOnnxDomain, the domain should be left unspecified (""). + // For ops in other domains, the domain should be specified (":"). + // Ex: "Conv", "com.microsoft:Conv", "com.ms.internal.nhwc:Conv" using OpVersionsMap = std::unordered_map>; + // Helper function to create a key to OpVersionsMap using domain and op_type. + static std::string OpVersionsMapKey(std::string_view op_type, std::string_view domain = kOnnxDomain) { + return (domain == kOnnxDomain) + ? std::string{op_type} + : std::string{domain} + ":" + std::string{op_type}; + } + struct Entry { Entry(const std::string& name_in, #if !defined(ORT_MINIMAL_BUILD) @@ -95,14 +107,15 @@ class SelectorActionRegistry { #if !defined(ORT_MINIMAL_BUILD) // return registered Entry or nullptr if not found - auto LookUpByOpType(const std::string& op_type) const -> std::vector>; + auto LookUpByOpTypeAndDomain(const std::string& op_type, + const std::string& domain) const -> std::vector>; #endif // !defined(ORT_MINIMAL_BUILD) private: std::unordered_map name_to_entry_; #if !defined(ORT_MINIMAL_BUILD) - // auxiliary mapping to enable lookup by op type + // auxiliary mapping to enable lookup by op type or "domain:op type" std::unordered_multimap op_type_to_entry_; #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index 2e07124dcd901..68336c996a863 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -16,6 +16,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, @@ -23,6 +24,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Conv, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index fdf3e5b6c6b66..3a01a4aa46be4 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -3,23 +3,42 @@ #pragma once +#include +#include + #include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" namespace onnxruntime { namespace js { -template -class Conv : public JsKernel { +class ConvBase : public JsKernel { public: - Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { + ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info), + conv_attrs_(info), + w_is_const_(false) { + std::vector activation_params; TensorShapeVector kernel_shape; + const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size(); + std::vector local_pads(pads_vec_size, 0); + for (size_t i = 0; i < conv_attrs_.pads.size() && i < pads_vec_size; ++i) { + local_pads[i] = gsl::narrow_cast(conv_attrs_.pads[i]); + } + if (conv_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - + if (is_fused_conv) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_attrs_.activation)); + ORT_ENFORCE(info.GetAttrs("activation_params", activation_params).IsOK()); + } else { + conv_attrs_.activation = info.GetAttrOrDefault("activation", ""); + activation_params = info.GetAttrsOrDefault("activation_params", activation_params); + } + const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr; int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); - + auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; + auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_attrs_.dilations.size() == 1 || (conv_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || @@ -30,44 +49,52 @@ class Conv : public JsKernel { "dilations" : [$2], "group" : $3, "kernel_shape" : [$4], - "pads" : [ $5, $6 ], + "pads" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : [], "strides" : [$7], - "w_is_const" : () JS_ARROW(!!HEAP8[$9]) + "w_is_const" : () JS_ARROW(!!HEAP8[$9]), + "activation" : UTF8ToString($10), + "activation_params" : $11 ? Array.from(HEAPF32.subarray($12, $12 + $11)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), + static_cast(kernel_shape_0), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ - "format" : $13 ? "NHWC" : "NCHW", + "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, "dilations" : [ $2, $3 ], "group" : $4, "kernel_shape" : [ $5, $6 ], - "pads" : [ $7, $8, $9, $10 ], - "strides" : [ $11, $12 ], - "w_is_const" : () JS_ARROW(!!HEAP8[$14]) + "pads" : $7 ? Array.from(HEAP32.subarray($8, $8 + $7)) : [], + "strides" : [ $9, $10 ], + "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "activation" : UTF8ToString($13), + "activation_params" : $14 ? Array.from(HEAPF32.subarray($15, $15 + $14)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), - static_cast(conv_attrs_.pads.size() > 2 ? conv_attrs_.pads[2] : 0), - static_cast(conv_attrs_.pads.size() > 3 ? conv_attrs_.pads[3] : 0), + static_cast(kernel_shape_0), + static_cast(kernel_shape_1), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } } @@ -94,5 +121,12 @@ class Conv : public JsKernel { // Tensor w_transposed_; }; +template +class Conv : public ConvBase { + public: + explicit Conv(const OpKernelInfo& info) : ConvBase(info, is_channels_last, is_fused_conv) { + } +}; + } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 2228343e1e6e3..f7f0ab22b7006 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -15,6 +15,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_KERNEL_EX( ConvTranspose, kOnnxDomain, @@ -22,6 +23,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( ConvTranspose, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 18ef73268005d..5d30dc851e00f 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -4,26 +4,45 @@ #pragma once #include +#include #include "core/common/gsl.h" #include "core/providers/cpu/nn/conv_transpose_attributes.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { TensorShapeVector kernel_shape; + if (is_fused_convtranspose) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_transpose_attrs_.activation)); + } else { + conv_transpose_attrs_.activation = info.GetAttrOrDefault("activation", ""); + } + if (conv_transpose_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); + std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), + conv_transpose_attrs_.output_shape.end()); + std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), + conv_transpose_attrs_.output_padding.end()); + const auto* local_output_padding_ptr = + local_output_padding.size() > 0 ? local_output_padding.data() : nullptr; + const auto* local_output_shape_ptr = + local_output_shape.size() > 0 ? local_output_shape.data() : nullptr; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_transpose_attrs_.dilations.size() == 1 || (conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || conv_transpose_attrs_.strides.size() == 1) { + auto dilations = conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0; + auto kernel_shape_0 = conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; + auto pads_0 = conv_transpose_attrs_.pads.size() > 0 ? conv_transpose_attrs_.pads[0] : 0; + auto pads_1 = conv_transpose_attrs_.pads.size() > 1 ? conv_transpose_attrs_.pads[1] : 0; + auto strides = conv_transpose_attrs_.strides.size() > 0 ? conv_transpose_attrs_.strides[0] : 0; JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $8 ? "NHWC" : "NCHW", "autoPad" : $1, @@ -34,21 +53,23 @@ class ConvTranspose : public JsKernel { "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), "outputPadding" : $10 ? Array.from(HEAP32.subarray($11, $11 + $10)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [] + "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [], + "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), - static_cast(conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0), + static_cast(dilations), static_cast(conv_transpose_attrs_.group), - static_cast(conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0) ? kernel_shape[0] : 0, - static_cast(conv_transpose_attrs_.pads.size()), - static_cast(conv_transpose_attrs_.pads.size() > 1) ? conv_transpose_attrs_.pads[1] : 0, - static_cast(conv_transpose_attrs_.strides.size() > 0) ? conv_transpose_attrs_.strides[0] : 0, + static_cast(kernel_shape_0), + static_cast(pads_0), + static_cast(pads_1), + static_cast(strides), static_cast(channels_last), reinterpret_cast(&w_is_const_), - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_padding.size() > 0 ? conv_transpose_attrs_.output_padding.data() : nullptr) >> 2, - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_shape.size() > 0 ? conv_transpose_attrs_.output_shape.data() : nullptr) >> 2); + gsl::narrow_cast(local_output_padding.size()), + reinterpret_cast(local_output_padding_ptr) >> 2, + gsl::narrow_cast(local_output_shape.size()), + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } else { constexpr size_t pads_vec_size = 4; constexpr size_t strides_vec_size = 2; @@ -59,8 +80,6 @@ class ConvTranspose : public JsKernel { std::vector local_strides(strides_vec_size, 0); std::vector local_dilations(dialations_vec_size, 0); std::vector local_kernel_shape; - std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), conv_transpose_attrs_.output_shape.end()); - std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), conv_transpose_attrs_.output_padding.end()); if (conv_transpose_attrs_.kernel_shape_specified) { for (size_t i = 0; i < kernel_shape.size() && i < kernel_shape_vec_size; ++i) { local_kernel_shape.push_back(gsl::narrow_cast(kernel_shape[i])); @@ -91,7 +110,8 @@ class ConvTranspose : public JsKernel { "strides" : Array.from(HEAP32.subarray($6, $6 + /* strides_vec_size */ 2)), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), "outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10, $10 + $9)) : [], - "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [] + "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [], + "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), reinterpret_cast(local_dilations.data()) >> 2, @@ -102,9 +122,10 @@ class ConvTranspose : public JsKernel { static_cast(channels_last), reinterpret_cast(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - reinterpret_cast(local_output_padding.size() > 0 ? local_output_padding.data() : nullptr) >> 2, + reinterpret_cast(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - reinterpret_cast(local_output_shape.size() > 0 ? local_output_shape.data() : nullptr) >> 2); + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 46b95a127b75c..a6aa4b946f397 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1438,7 +1438,7 @@ TEST_F(GraphTransformationTests, NotWhereFusion) { ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where } -#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS) +#if (defined(USE_CUDA) || defined(USE_JSEP)) && !defined(DISABLE_CONTRIB_OPS) // Conv->Add->Relu will be transformed to FusedConv TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx"; @@ -1618,6 +1618,10 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCudaExecutionProvider); } +#elif defined(USE_JSEP) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kJsExecutionProvider); + } #endif std::map op_to_count_before_fusion = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count_before_fusion[model.second] >= 1); @@ -1632,6 +1636,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { std::set cuda_rocm_supported = {"Relu"}; if (cuda_rocm_supported.find(model.second) == cuda_rocm_supported.end()) { ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); + } else { + ASSERT_EQ(op_to_count_after_fusion[model.second], 0); + } +#elif defined(USE_JSEP) + std::set js_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu"}; + if (js_supported.find(model.second) == js_supported.end()) { + ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); } else { ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); }